In [None]:
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import sys
sys.path.append("../")
from torchvision import datasets, transforms
import os
import pickle
from tqdm import tqdm
from data_utils import prepare_drift_dataset
from model_utils import initialize_model
from training_utils import train_model
import os


print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# ---- Model ----------------------------------------------------------------
# `model_name` is forwarded to initialize_model and also names the output
# sub-directory (see experiment_root below).
model_name = "resnet"
num_classes = 5          # one output per entry of target_classes
feature_extract = True   # forwarded to initialize_model; when set, only parameters
                         # with requires_grad=True are handed to the optimizer
input_size = 224         # side length used by the crop/resize transforms

# ---- Optimisation -----------------------------------------------------------
batch_size = 8
num_epochs = 10

# ---- Dataset construction ---------------------------------------------------
SEED = 4       # forwarded to prepare_drift_dataset and stored in the result log
n_train = 150  # upper bound on per-concept training images (min with folder size)
n_drift = 50   # forwarded to prepare_drift_dataset as n_drift

# ---- Hardware ---------------------------------------------------------------
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ---- Paths ------------------------------------------------------------------
data_root = '/path/dataset/metadataset/MetaDataset/subsets'
experiment_root = f'/path/outputs/conceptualexplanations/metadataset/{model_name}'


In [None]:
# The animal classes the classifier is trained to tell apart.
target_classes = ["dog", "cat", "bird", "bear", "elephant"]

# For every animal, the subset folders that end up as the experiment's
# "out_distributions" (see the cell that builds valid_experiments).
out_dists = {
    "dog": ["dog(water)", "cat(water)"],
    "cat": ["dog(water)", "cat(water)"],
    "bird": ["bear(rock)", "bird(rock)"],
    "bear": ["bear(rock)", "bird(rock)"],
    "elephant": ["bear(grass)", "elephant(water)"],
}

# Replacement context used when the experiment's own concept already appears in
# its out-distribution list: the context name is swapped for the value given here.
c_alternatives = {
    "car": "water",
    "water": "car",
    "rock": "water",
}

# Every entry found on disk under <data_root>/<animal>, keyed by animal.
class_concepts = {
    cls: os.listdir(os.path.join(data_root, cls))
    for cls in target_classes
}

# Concept subsets ("animal(context)") selected for experiments.
experiments = [
    "dog(chair)", "cat(cabinet)", "dog(snow)", "dog(car)",
    "dog(horse)", "bird(water)", "dog(water)", "dog(fence)",
    "elephant(building)", "cat(keyboard)", "dog(sand)", "cat(computer)",
    "dog(bed)", "cat(bed)", "cat(book)", "dog(grass)",
    "cat(mirror)", "bird(sand)", "bear(chair)", "cat(grass)",
]

In [None]:
# Build one configuration dict per runnable experiment, keyed by concept name.
# A concept is runnable when it is listed in `experiments` and its folder holds
# at least 50 entries.
valid_experiments = {}
for animal, concept_list in class_concepts.items():
    for concept in concept_list:
        if concept not in experiments:
            continue

        concept_path = os.path.join(data_root, animal, concept)
        # os.listdir counts every directory entry, not only image files.
        im_count = len(os.listdir(concept_path))
        if im_count < 50:
            print(f"#Images for {concept}: {im_count}. Skipping for now.")
            continue

        # Out-of-distribution folders default to the animal's preset list. If the
        # experiment's own concept is in that list, swap its context (the text in
        # parentheses) for the configured alternative in every entry.
        out_dist = out_dists[animal]
        if concept in out_dist:
            c_name = concept.split("(")[1][:-1]
            alt = c_alternatives[c_name]  # KeyError if no alternative is configured
            out_dist = [d.replace(c_name, alt) for d in out_dist]

        # The entry is created in a single step so a failure above never leaves a
        # partially filled configuration behind.
        valid_experiments[concept] = {
            "animal": animal,
            "path": concept_path,
            "im_count": im_count,
            "n_train": min(im_count, n_train),
            # Same class list as target_classes, with this animal replaced by the
            # concept-specific subset.
            "in_distributions": [concept if a == animal else a for a in target_classes],
            "out_distributions": out_dist,
        }

## Create Datasets and Train models

In [None]:
# Folder layout returned by prepare_drift_dataset for every experiment run here.
all_folders = {}

for concept, concept_config in valid_experiments.items():
    # An existing output directory is taken to mean the experiment is finished.
    # NOTE(review): if prepare_drift_dataset creates this directory itself, a run
    # interrupted during training would also be skipped here - confirm.
    if os.path.exists(os.path.join(experiment_root, concept)):
        print(f"{concept} model is already trained!")
        continue
    print(concept, concept_config)

    # Dataset preparation is best-effort: a failure skips this experiment only.
    try:
        folders = prepare_drift_dataset(data_root, experiment_root,
                                        concept_config["in_distributions"],
                                        concept_config["out_distributions"], seed=SEED,
                                        n_train=concept_config["n_train"], n_drift=n_drift)
    except Exception as e:
        print(f"Skipping this experiment due to the following error: {e}")
        continue

    all_folders[concept] = folders

    # Initialise the model BEFORE building the transforms: initialize_model returns
    # the input size, and the crops below must use that value for every experiment
    # (previously the first iteration used the configured value and later ones the
    # value returned by the previous iteration's model).
    model_ft, input_size = initialize_model(num_classes, feature_extract, use_pretrained=True, model_name=model_name)
    model_ft = model_ft.to(device)

    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Create training and validation datasets
    image_datasets = {x: datasets.ImageFolder(folders[x], data_transforms[x]) for x in ['train', 'val']}
    # Create training and validation dataloaders
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}
    print("Datasets and loaders are ready.")

    # List the trainable parameters. With feature_extract only those are handed to
    # the optimizer; otherwise the optimizer receives all model parameters.
    print("Params to learn:")
    trainable = [(name, param) for name, param in model_ft.named_parameters() if param.requires_grad]
    for name, _ in trainable:
        print("\t", name)
    if feature_extract:
        params_to_update = [param for _, param in trainable]
    else:
        params_to_update = model_ft.parameters()

    optimizer_ft = optim.Adam(params_to_update, lr=0.001)

    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()

    model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"),
                                device=device)
    # Convert the history entries to plain Python numbers first, so the maximum is
    # taken over floats rather than over objects NumPy may be unable to convert
    # (e.g. tensors living on the GPU).
    hist = [h.item() for h in hist]
    best_val_acc = np.max(hist).item()

    result_log = {
        "val_acc_hist": hist,
        "best_val_acc": best_val_acc,
        "in_dists": concept_config["in_distributions"],
        "out_dists": concept_config["out_distributions"],
        "seed": SEED,
    }
    with open(os.path.join(folders["result"], "result_log.pkl"), "wb") as f:
        pickle.dump(result_log, f)

    with open(os.path.join(folders["result"], "concept_config.pkl"), "wb") as f:
        pickle.dump(concept_config, f)

    # Pass the path so torch.save opens and closes the file itself; the previous
    # bare open(...) handle was never closed.
    torch.save(model_ft, os.path.join(folders["result"], "confounded-model.pt"))
