# Run everything in this section to define functions and imports. 

In [None]:
# Imports. 
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import os
import numpy as np
import seaborn as sns
import csv
import matplotlib.pyplot as plt
import json
from collections import Counter
from nltk.corpus import wordnet as wn
from tqdm import tqdm

### Set the paths below to match your local setup before running the rest of the code.

In [None]:
# Directory holding the SNARE fold files (train.json, val.json, test.json).
snare_path = '/home/rcorona/obj_part_lang/snare-master/amt/folds_adversarial'
# CSV of per-object metadata: object ID is read from column 0, synset code from column 2.
metadata_path = './data/metadata.csv'
# CSV mapping synset codes (column 2) to comma-separated words (column 3).
categories_path = './data/categories.synset.csv'
# Feature directories: each is expected to contain one '<object_id>.npy' file per object.
lfn_feat_dir = './data/lfn_feats'
clip_feat_dir = '/home/rcorona/dev/snare-master/data/shapenet-clipViT32-frames/'
pixelnerf_feat_dir = '/home/rcorona/2022/lang_nerf/vlg/snare-master/data/pixelnerf_custom_feats'
legoformer_feat_dir = '/home/rcorona/2022/lang_nerf/vlg/snare-master/data/legoformer_multiview_feats/'

In [None]:
def get_snare_objs():
    """
    Get all ShapeNetSem object IDs for objects used in SNARE.

    Reads train.json, val.json and test.json from `snare_path` and gathers
    every ID listed under each datapoint's 'objects' key.

    Returns:
        list: The unique object IDs across the three splits, sorted so that
        the order (and anything derived from it, such as a positional
        dataset split) is reproducible between runs.
    """
    all_objs = set()

    for split_name in ('train', 'val', 'test'):
        split_file = os.path.join(snare_path, '{}.json'.format(split_name))

        # Context manager guarantees the file handle is closed after parsing.
        with open(split_file) as f:
            split = json.load(f)

        for datapoint in split:
            all_objs.update(datapoint['objects'])

    # Sorting assumes the IDs are mutually comparable (they are used as
    # file-name stems and CSV keys elsewhere in this file, i.e. strings).
    return sorted(all_objs)

In [None]:
def get_plot_labels(objs):
    """
    Generate category labels for TSNE plot.

    Args:
        objs: Iterable of object IDs. Each ID must match column 0 of the
            metadata CSV once the 'wss.' prefix is removed.

    Returns:
        tuple: (snare_synsets, counts).
            snare_synsets is a list with one entry per object: the first
            word listed for the object's synset, or None when the synset
            code has no entry in the categories CSV.
            counts is a Counter of how often each non-empty word occurs.

    Raises:
        KeyError: If an object ID is not present in the metadata CSV.
    """
    # Map object ID -> synset code (every row is used, header included,
    # which is harmless because the header never matches an object ID).
    with open(metadata_path, 'r', newline='') as csvfile:
        obj2synset = {
            row[0].replace('wss.', '').strip(): row[2].strip()
            for row in csv.reader(csvfile)
        }

    # Map synset code -> first listed word, skipping the header row.
    with open(categories_path, 'r', newline='') as csvfile:
        mappings = list(csv.reader(csvfile))

    s2w = {r[2].strip(): r[3].split(',')[0].strip() for r in mappings[1:]}

    snare_synsets = []
    counts = Counter()

    for obj in objs:
        # None marks objects whose synset code has no word mapping.
        word = s2w.get(obj2synset[obj])
        snare_synsets.append(word)

        # Empty words are kept as labels but never counted.
        if word:
            counts[word] += 1

    return snare_synsets, counts

In [None]:
def load_lfn_features(objs):
    """
    Given list of object IDs, load features for each object under LFN.

    Args:
        objs: Iterable of object IDs; '<id>.npy' is read from `lfn_feat_dir`.

    Returns:
        list: One loaded array per object, in the same order as `objs`.
    """
    return [
        np.load(os.path.join(lfn_feat_dir, '{}.npy'.format(obj)))
        for obj in objs
    ]

In [None]:
def load_clip_features(objs):
    """
    Load one CLIP feature vector per object ID.

    For every ID, '<id>.npy' is read from `clip_feat_dir`, the first six
    entries along axis 0 are dropped, and the remainder is averaged over
    axis 0. (Presumably the retained entries are the input views -- confirm
    against the code that wrote these files.)

    Returns a list of the averaged arrays, ordered like `objs`.
    """

    def _mean_over_views(obj_id):
        frames = np.load(os.path.join(clip_feat_dir, '{}.npy'.format(obj_id)))
        return np.mean(frames[6:], axis=0)

    return [_mean_over_views(obj_id) for obj_id in objs]

In [None]:
def load_legoformer_features(objs):
    """
    Load one flattened LegoFormer feature vector per object ID.

    For every ID, '<id>.npy' is read from `legoformer_feat_dir` and
    reshaped to a single dimension so it can be treated as one feature.

    Returns a list of 1-D arrays, ordered like `objs`.
    """
    return [
        np.reshape(
            np.load(os.path.join(legoformer_feat_dir, '{}.npy'.format(obj_id))),
            -1,
        )
        for obj_id in objs
    ]

In [None]:
def load_pixelnerf_features(objs):
    """
    Load one pooled PixelNeRF feature vector per object ID.

    For every ID, '<id>.npy' is read from `pixelnerf_feat_dir`, reshaped to
    (8, 512, -1), averaged over axis 0 and then over the last axis, leaving
    a 512-element vector. Progress is reported through tqdm.

    Returns a list of the pooled arrays, ordered like `objs`.
    """
    print('Loading PixelNeRF features...')

    def _pool(obj_id):
        raw = np.load(os.path.join(pixelnerf_feat_dir, '{}.npy'.format(obj_id)))
        grid = np.reshape(raw, (8, 512, -1))
        # Average over axis 0 first, then over the trailing axis.
        return np.mean(np.mean(grid, axis=0), axis=-1)

    return [_pool(obj_id) for obj_id in tqdm(objs)]

In [None]:
    def filter_by_top_k(counts, feats, objs, snare_synsets):
    # Get 10 most common categories. 
    top10 = {t[0]: t[1] for t in counts.most_common(10)}

    # Only keep objects in top-10 categories. 
    final_feats = []
    final_labels = []

    assert len(objs) == len(feats) and len(objs) == len(snare_synsets)

    for i in range(len(objs)):
        synset = snare_synsets[i]
        
        if synset in top10: 
            final_feats.append(feats[i])
            final_labels.append(synset)
            
    # Create numpy array of features. 
    final_feats = np.stack(final_feats)
    print('Final feature shape: {}'.format(final_feats.shape))
    
    return final_feats, final_labels

In [None]:
# Two-stage dimensionality reduction: PCA first, then TSNE.
def gen_tsne_feats(final_feats):
    """
    Project features to two dimensions for plotting.

    The input is first reduced to 3 components with PCA, and that result is
    embedded into 2 dimensions with TSNE (random_state fixed to 0). The
    shape of each intermediate result is printed.

    Returns the 2-D TSNE embedding.
    """
    pca = PCA(n_components=3)
    reduced = pca.fit_transform(final_feats)
    print('PCA Feat shape: {}'.format(reduced.shape))

    tsne = TSNE(n_components=2, random_state=0)
    embedded = tsne.fit_transform(reduced)
    print('TSNE Feat shape: {}'.format(embedded.shape))

    return embedded

In [None]:
def gen_plot(tsne_feats, final_labels, title):
    """
    Draw a scatter plot of a 2-D embedding, coloured by category label.

    Column 0 of `tsne_feats` is used as x and column 1 as y. Points are
    coloured by `final_labels` using a 10-colour "hls" palette, and the
    legend is anchored outside the axes at the upper right.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    xs, ys = tsne_feats[:, 0], tsne_feats[:, 1]
    colours = sns.color_palette("hls", 10)

    axes = sns.scatterplot(x=xs, y=ys, hue=final_labels, palette=colours)
    axes.set(title=title)

    # Anchor the legend just outside the plotting area.
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

In [None]:
# Collect the IDs of all ShapeNetSem objects referenced by the SNARE splits.
snare_objs = get_snare_objs()

# Per-object category labels (None where unlabeled) and a Counter of category
# frequencies; both are reused by every visualization and probe cell below.
synsets, counts = get_plot_labels(snare_objs)

# LFN Feature Visualization

In [None]:
# Load LFN features for every SNARE object.
lfn_feats = load_lfn_features(snare_objs)

# Keep only objects that belong to the 10 most frequent categories.
final_feats, final_labels = filter_by_top_k(counts, lfn_feats, snare_objs, synsets)

# Reduce to 2-D (PCA followed by TSNE).
tsne_feats = gen_tsne_feats(final_feats)

# Scatter-plot the embedding, coloured by category.
gen_plot(tsne_feats, final_labels, title='LFN TSNE Plot')

# CLIP Feature Visualization

In [None]:
# Load CLIP features for every SNARE object.
clip_feats = load_clip_features(snare_objs)

# Keep only objects that belong to the 10 most frequent categories.
final_feats, final_labels = filter_by_top_k(counts, clip_feats, snare_objs, synsets)

# Reduce to 2-D (PCA followed by TSNE).
tsne_feats = gen_tsne_feats(final_feats)

# Scatter-plot the embedding, coloured by category.
gen_plot(tsne_feats, final_labels, title='CLIP TSNE Plot')

# PixelNeRF Feature Visualization

In [None]:
# Load PixelNeRF features for every SNARE object.
pixelnerf_feats = load_pixelnerf_features(snare_objs)

# Keep only objects that belong to the 10 most frequent categories.
final_feats, final_labels = filter_by_top_k(counts, pixelnerf_feats, snare_objs, synsets)

# Reduce to 2-D (PCA followed by TSNE).
tsne_feats = gen_tsne_feats(final_feats)

# Scatter-plot the embedding, coloured by category.
gen_plot(tsne_feats, final_labels, title='PixelNeRF TSNE Plot')

# VLG (LegoFormer) Feature Visualization

In [None]:
# Load LegoFormer features for every SNARE object.
legoformer_feats = load_legoformer_features(snare_objs)

# Keep only objects that belong to the 10 most frequent categories.
final_feats, final_labels = filter_by_top_k(counts, legoformer_feats, snare_objs, synsets)

# Reduce to 2-D (PCA followed by TSNE).
tsne_feats = gen_tsne_feats(final_feats)

# Scatter-plot the embedding, coloured by category.
gen_plot(tsne_feats, final_labels, title='LegoFormer TSNE Plot')

# Linear Probe Functions and Classes

In [None]:
# Imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torch.nn.functional as F
import os
import numpy as np

In [None]:
# Dataset wrapper consumed by the linear probes.
class ProbeDataset(Dataset):
    """
    Pairs per-object features with integer class IDs.

    feats    -- indexable collection of per-object features
    labels   -- indexable collection of category labels, aligned with feats
    idx_dict -- mapping from category label to integer class ID
    """

    def __init__(self, feats, labels, idx_dict):
        # Label -> class ID lookup.
        self.idx_dict = idx_dict

        # Category label of each item.
        self.labels = labels

        # Feature of each item.
        self.feats = feats

    def __len__(self):
        # The dataset size is defined by the number of features.
        n_items = len(self.feats)
        return n_items

    def __getitem__(self, idx):
        # Feature as a float tensor, label translated to its class ID.
        raw_feat = self.feats[idx]
        feat = torch.Tensor(raw_feat).float()

        label_name = self.labels[idx]
        class_id = self.idx_dict[label_name]

        return feat, class_id

In [None]:
class LinearProbe(nn.Module):
    """
    A single linear layer mapping features to per-category logits.

    feat_dim     -- size of the input feature vector
    n_categories -- number of output logits
    """

    def __init__(self, feat_dim, n_categories):
        super().__init__()

        # The probe itself: one affine map, nothing else.
        self.probe = nn.Linear(feat_dim, n_categories)

        # Keep the dimensions around for reference.
        self.feat_dim, self.n_categories = feat_dim, n_categories

    def forward(self, x):
        logits = self.probe(x)
        return logits

In [None]:
# Training loop.
def train_loop(model, train_dataloader, optimizer=None):
    """
    Run one epoch of cross-entropy training over `train_dataloader`.

    Args:
        model: torch.nn.Module that maps a batch of features to logits.
        train_dataloader: Iterable yielding (feats, labels) batches.
        optimizer: Optional torch optimizer to reuse across calls. When
            omitted, a new Adam optimizer with default settings is created,
            which means Adam's running statistics start from scratch on
            every call.

    Returns:
        None. The model's parameters are updated in place.
    """
    # Initialize optimizer unless the caller supplied one to reuse.
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters())

    # Send each batch to wherever the model lives instead of assuming a GPU.
    device = next(model.parameters()).device

    # Place model on train mode.
    model.train()

    # Do one epoch of updates.
    for feats, labels in tqdm(train_dataloader):
        feats = feats.to(device)
        labels = labels.to(device)

        # Zero out gradients on optimizer.
        optimizer.zero_grad()

        # Forward pass.
        logits = model(feats)

        # Compute CE loss and take update.
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

In [None]:
# Evaluate model accuracy on a dataset split.
def eval_model(model, loader):
    """
    Compute classification accuracy of `model` over every batch in `loader`.

    Args:
        model: torch.nn.Module that maps a batch of features to logits of
            shape (batch, n_classes); the prediction is the argmax over
            dim 1.
        loader: Iterable yielding (feats, labels) batches.

    Returns:
        The fraction of datapoints whose predicted class equals the label.

    Raises:
        ValueError: If `loader` yields no batches.
    """
    # Keep track of correctness across all datapoints.
    all_correct = []

    # Send each batch to wherever the model lives; a model without
    # parameters falls back to the GPU, as the data was always moved there.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # Place model on eval mode.
    model.eval()

    # Go over entire dataset without tracking gradients.
    with torch.no_grad():
        for feats, labels in tqdm(loader):
            feats = feats.to(device)
            labels = labels.to(device)

            # Forward pass and prediction.
            logits = model(feats)
            preds = logits.argmax(dim=1)

            # Per-datapoint 0/1 correctness for this batch.
            correct = torch.eq(preds, labels).long()
            all_correct.append(correct.cpu().numpy())

    # Fail with a clear message instead of np.concatenate's generic one.
    if not all_correct:
        raise ValueError('Cannot compute accuracy: loader yielded no batches.')

    # Compute dataset split accuracy.
    all_correct = np.concatenate(all_correct)
    acc = np.mean(all_correct)

    return acc

In [None]:
# Training pipeline.
def run_probe(model, train_loader, val_loader, test_loader, n_epochs):
    """
    Train a probe, select the best epoch on validation, and test it.

    After every training epoch the model is evaluated on `val_loader`. The
    weights of the epoch with the highest validation accuracy so far are
    written to a temporary checkpoint ('probe.pth' in the working
    directory). Once training ends those weights are restored, the model is
    evaluated on `test_loader`, and the checkpoint file is deleted.

    Args:
        model: The probe to train (updated in place).
        train_loader: Batches used for training.
        val_loader: Batches used for checkpoint selection.
        test_loader: Batches used for the final evaluation.
        n_epochs: Number of training epochs.

    Returns:
        The test accuracy of the selected checkpoint.
    """
    ckpt_path = 'probe.pth'

    # Best validation accuracy seen so far, and whether a checkpoint exists.
    best_acc = 0.0
    saved = False

    try:
        for _ in range(n_epochs):

            # Do a training iteration with model.
            train_loop(model, train_loader)

            # Evaluate model on validation set.
            val_acc = eval_model(model, val_loader)

            # Always checkpoint the first epoch so there is something to
            # restore, then only on a strict improvement.
            if not saved or val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), ckpt_path)
                saved = True

            # Print best accuracy.
            print('Best Acc: {}'.format(best_acc))

        # Evaluate best checkpoint on test set.
        if saved:
            model.load_state_dict(torch.load(ckpt_path))

        test_acc = eval_model(model, test_loader)
        print('Probe test performance: {}'.format(test_acc))
    finally:
        # Get rid of the temporary checkpoint even if something failed.
        if saved and os.path.exists(ckpt_path):
            os.remove(ckpt_path)

    return test_acc

In [None]:
def linear_probe(snare_objs, counts, synsets, feat_load_func, feat_dim,
                 batch_size=64, n_epochs=30):
    """
    Train and evaluate a linear category classifier on one feature type.

    Features are loaded with `feat_load_func`, restricted to the 10 most
    frequent categories, split positionally into 80% train / 10% validation
    / remaining test, and handed to `run_probe`.

    Args:
        snare_objs: List of object IDs.
        counts: Counter of category frequencies.
        synsets: Per-object category labels aligned with `snare_objs`.
        feat_load_func: Callable taking the object ID list and returning one
            feature per object.
        feat_dim: Dimensionality of each feature vector.
        batch_size: Batch size for all three data loaders. Defaults to 64.
        n_epochs: Number of training epochs. Defaults to 30.

    Returns:
        Whatever `run_probe` returns.
    """
    # TODO Set GPU number here.
    # NOTE(review): this probably has no effect if CUDA was already
    # initialised earlier in the process -- confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Load features and labels.
    feats = feat_load_func(snare_objs)

    # Filter features by top 10 most occurring categories.
    final_feats, final_labels = filter_by_top_k(counts, feats, snare_objs, synsets)

    # Compute dataset split lengths.
    n_total = len(final_feats)
    train_len = int(n_total * 0.8)
    val_len = int(n_total * 0.1)
    val_end = train_len + val_len

    # Class IDs must be contiguous in [0, n_categories) to be valid
    # cross-entropy targets, so they are derived from the labels that
    # survived filtering rather than from the per-object label list.
    idx_dict = {label: i for i, label in enumerate(sorted(set(final_labels)))}
    n_categories = len(idx_dict)

    # Split into datasets. The split is positional, so it follows the order
    # of `snare_objs`.
    train_dataset = ProbeDataset(final_feats[:train_len], final_labels[:train_len], idx_dict)
    val_dataset = ProbeDataset(final_feats[train_len:val_end], final_labels[train_len:val_end], idx_dict)
    test_dataset = ProbeDataset(final_feats[val_end:], final_labels[val_end:], idx_dict)

    # Form dataloaders.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Instantiate model with one output per category actually present.
    model = LinearProbe(feat_dim, n_categories)
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # Run linear probe.
    return run_probe(model, train_loader, val_loader, test_loader, n_epochs)

# CLIP Linear Probe

In [None]:
# Train and evaluate a linear probe on the CLIP features, which are passed to
# the probe as 512-dimensional vectors (feat_dim=512).
print('Training CLIP Linear Probe...')
linear_probe(snare_objs, counts, synsets, load_clip_features, 512)