# Pseudolabel images with a teacher model
Load in a model trained on the (smaller) CUB 200 2010 set and pseudolabel images on larger CUB 200 2011 set.

In [30]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import datetime
from src.shared import weights_path, data_path, gen_train_val, data_transforms, should_print
from src.PseudolabelDataset import PseudolabelDataset


plt.ion()   # interactive mode

In [31]:
is_local = True

In [32]:
# path of weights transfered from CCV
TRAINED_MODEL_PATH = "weights/resnet50_CUB200_66pct"
# TRAINED_MODEL_PATH = "weights/student_CUB200_Dec4"

In [33]:
# Load smaller labeled set (CUB 200 2010)
labeled_dataset_name = "CUB_200"
labeled_data_dir = data_path(labeled_dataset_name)
labeled_dataset = datasets.ImageFolder(
    labeled_data_dir, data_transforms['train'])

labeled_image_datasets = gen_train_val(labeled_dataset)
labeled_dataloaders = {x: torch.utils.data.DataLoader(labeled_image_datasets[x], batch_size=BATCH_SIZE,
                                              shuffle=True)
               for x in ['train', 'val']}
labeled_dataset_sizes = {x: len(labeled_image_datasets[x]) for x in ['train', 'val']}

class_names = labeled_dataset.classes

# Load larger unlabeled set (CUB 200 2011)
unlabeled_data_dir = data_path("CUB_200_2011/CUB_200_2011/images", is_local=is_local)
unlabeled_dataset = datasets.ImageFolder(
    unlabeled_data_dir, data_transforms['train'])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

FileNotFoundError: [Errno 2] No such file or directory: '/users/tjiang12/data/tjiang12/CUB_200'

In [None]:
# Configure teacher model's architecture
teacher = models.resnet50()

# Augment last layer to match dimensions
num_classes = 200
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, num_classes)

teacher = teacher.to(device)

# Load in trained model as teacher
teacher.load_state_dict(torch.load(TRAINED_MODEL_PATH, map_location=torch.device('cpu')))

In [None]:
# Experiment: Evaluate trained model's performance on unlabeled data
teacher.eval()

with torch.no_grad():
    running_corrects = 0
    seen = 0
    for i, (inputs, labels) in enumerate(unlabeled_dataloader):
        inputs = inputs.to(device)

        outputs = teacher(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        running_corrects += torch.sum(preds == labels.data)
        seen += len(inputs)
        print(running_corrects, "/", seen, running_corrects / seen)

In [None]:
outputs
sm = nn.Softmax(dim=1)
probs = sm(outputs)

# sanity check
assert abs(torch.sum(probs[0]) - 1) <= 0.005

print(torch.max(probs, dim=1).values)
print(torch.max(probs, dim=1).values > 0.75)
valid_confidence = torch.max(probs, dim=1).values > 0.3
probs[valid_confidence].shape

In [None]:
labeled_dataset.__getitem__(2)
teacher(labeled_dataset)

In [None]:
sm = nn.Softmax(dim=1)

# Write a dataloader that randomly picks either the pseudo labeled or correctly labeled data
class StudentDataset(torch.utils.data.Dataset):

    def __init__(self, teacher, labeled, unlabeled, transform=None):
        self.teacher = teacher
        self.labeled = labeled
        self.unlabeled = unlabeled
    
    def __len__(self):
        return len(self.labeled) + len(self.unlabeled)

    def __getitem__(self, idx):
        if idx < len(self.labeled):
#             print("labeled")
            return self.labeled.__getitem__(idx)
        
        idx = idx - len(self.labeled)
#         print("unlabeled")
        img, truelabel = self.unlabeled.__getitem__(idx)
        logits = self.teacher(torch.reshape(img, (1, 3, 224, 224)))
        probs = sm(logits)
        value, prediction = torch.max(probs, dim=1)
        pseudolabel = int(prediction)
#         if value < 0.75:
#             print("LOW CONF", truelabel, pseudolabel)
#             return img, -1
#         print(truelabel, pseudolabel, value)
        return img, pseudolabel
    
s = StudentDataset(teacher, labeled_dataset, unlabeled_dataset)

# for i, (img, label) in enumerate(s):
#     if i > len(s.labeled):
#         print(i)
# print(labeled_dataset.__getitem__(0))
# print(s.__getitem__(0))

# # print(unlabeled_dataset.__getitem__(0))
# for i in range(len(s.labeled), len(s), 30):
#     print(s.__getitem__(i))
# print(len(s.labeled))
# print(len(s.unlabeled))

In [None]:
img, _ = s.__getitem__(0)
print(img.shape)

logits = teacher(torch.reshape(img, (1, 3, 224, 224)))
probs = sm(logits)
value, prediction = torch.max(probs, dim=1)

print(value, value > 0.3)

In [None]:
combined_loader = torch.utils.data.DataLoader(s, batch_size=16, shuffle=True)

In [None]:
# Maybe only keep those > 0.85 
for i, j in combined_loader:
    print(j)
#     valid = j > 0
#     print(torch.sum(valid))
#     print(i[valid].shape)
#     print(j[valid].shape)

In [None]:
# Configure student model's architecture
student = models.resnet50()

# Augment last layer to match dimensions
num_classes = 200
num_ftrs = student.fc.in_features
student.fc = nn.Linear(num_ftrs, num_classes)

student = student.to(device)

In [None]:
def should_print(i):
    print_every = 15 if is_local else 200
    return i % print_every == 0

sm = nn.Softmax(dim=1)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # track best model weights
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # train for num_epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a phase: train and val
        for phase in ['train', 'val']:
            epoch_begin = time.time()
            if phase == 'train':
                # sets it in training mode
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # get batch
            for i, (inputs, labels) in enumerate(combined_loader):
                valid = labels != -1
                print(valid)
                print(inputs.shape)
                inputs = inputs[valid]
                labels = labels[valid]
                print(torch.sum(valid))
                print(inputs.shape)
                print(labels.shape)
                print(labels)
                if should_print(i):
                    time_elapsed = time.time() - epoch_begin
                    print(
                        i + 1, '/', len(combined_loader), int(time_elapsed), 'seconds')
                    print('ETA:', datetime.timedelta(seconds=int(
                        (time_elapsed / (i + 1)) * (len(combined_loader) - i))))
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero out the previous gradient
                optimizer.zero_grad()

                # dunno what this `with` does
                with torch.set_grad_enabled(phase == 'train'):
                    # forward pass
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(running_corrects.double(), "/", dataset_sizes[phase])
            print(running_corrects.double() / dataset_sizes[phase])

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                print("UPDATE:", best_acc, "to", epoch_acc)
                print("Saving model to", PATH)
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), PATH)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(student.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

student = train_model(student, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

In [None]:
softmax = nn.Softmax(dim=1)

class PseudolabelDataset(torch.utils.data.Dataset):
    def __init__(self, data, teacher, threshold=0, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
        self.data = data
        self.teacher = teacher
        self.threshold = threshold
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, _ = self.data.__getitem__(idx)
        img = img.to(self.device)

        logits = self.teacher(torch.reshape(img, (1, 3, 224, 224)))
        probs = softmax(logits)
        value, prediction = torch.max(probs, dim=1)
        pseudolabel = int(prediction)
        if value < self.threshold:
            return img, -1
        return img, pseudolabel


In [None]:
pd = PseudolabelDataset(unlabeled_dataset, teacher, device=device)

In [None]:
for num, (i, j) in enumerate(pd):
    if num == 0:
        print(i)
    pass

In [None]:
unlabeled_dataset.__getitem__(0)