# co-training

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
from tqdm import tqdm
import numpy as np
import copy
import random
import pickle
import os

In [3]:
class ConvNet(nn.Module):
    """Small three-stage CNN classifier for 28 x 28 images.

    The convolutional stack reduces an ``in_channels x 28 x 28`` input to
    ``128 x 3 x 3``; one linear layer then maps the flattened features to
    ``num_classes`` raw scores (no softmax is applied).

    Args:
        in_channels: Number of channels of the input images (default 3).
        num_classes: Number of scores produced per image (default 3).
    """

    def __init__(self, in_channels=3, num_classes=3):
        super().__init__()

        # Shapes in the comments are C x H x W for a 28 x 28 input.
        # Layer order is kept as-is so state_dict keys stay stable.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3),  # 32 x 26 x 26
            nn.MaxPool2d(2),                # 32 x 13 x 13
            nn.ReLU(),
            nn.Conv2d(32, 64, 3),           # 64 x 11 x 11
            nn.MaxPool2d(2),                # 64 x 5 x 5
            nn.ReLU(),
            nn.Conv2d(64, 128, 3),          # 128 x 3 x 3
            nn.ReLU(),
        )

        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, num_classes),
        )

    def forward(self, x):
        """Return the class scores, shape ``(batch, num_classes)``, for ``x``."""
        return self.fc1(self.block1(x))

In [4]:
def train(loader, model, loss_fn, optimizer, device):
    """Run one pass over ``loader``, updating ``model`` after every batch.

    Args:
        loader: Yields ``(inputs, targets)`` batches and exposes ``dataset``.
        model: Module to optimize; it is switched to training mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        optimizer: Optimizer that owns the model's parameters.
        device: Device the batches are moved to before the forward pass.
    """
    total = len(loader.dataset)
    model.train()
    for step, (inputs, targets) in enumerate(loader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # forward pass and loss for this batch
        batch_loss = loss_fn(model(inputs), targets)

        # backward pass; gradients are cleared after the update so the
        # next batch does not accumulate onto them
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Progress is reported for every batch. The running count is
        # batches-so-far times the size of the current batch, so it is only
        # exact while all batches have the same size.
        seen = step * len(inputs)
        print(f"loss: {batch_loss.item():>7f} [{seen:5d} / {total:>5d}]")

In [5]:
def predict(loader, model, device):
    """Run ``model`` over every batch of ``loader`` and stack the outputs.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` pairs. The labels are
            ignored, so they do not have to be tensors.
        model: Module to evaluate; it is switched to evaluation mode.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Tensor of shape (# instances, # outputs) whose rows follow the order
        in which ``loader`` yields the instances.

    Raises:
        RuntimeError: If ``loader`` yields no batches (``torch.cat`` rejects
            an empty list).
    """
    model.eval()
    outputs = []
    with torch.no_grad():
        # Only the inputs are needed; the label slot is discarded instead of
        # being copied to the device.
        for X, _ in loader:
            outputs.append(model(X.to(device)))
    return torch.cat(outputs)  # output shape (# instances, # outputs)

In [6]:
def get_topk_pred(pred, k):
    """Select the ``k`` most confident predictions.

    Args:
        pred: Tensor of shape (# instances, # outputs) with one score per
            output for every instance.
        k: Maximum number of instances to select. When ``pred`` has fewer
            rows, all of them are returned.

    Returns:
        Tuple ``(top scores, predicted labels, instance indexes)`` of CPU
        tensors, ordered by descending score. The instance indexes are row
        numbers of ``pred``.

    NOTE(review): rows are ranked by their largest score as given. If
    ``pred`` holds raw logits rather than probabilities, the ranking is by
    logit -- confirm that this is intended.
    """
    prob, label = torch.max(pred, 1)
    idx = torch.argsort(prob, descending=True)[:k]
    # idx already holds the row numbers of the selected instances. It must be
    # returned as-is; indexing it by itself scrambles the indexes or goes out
    # of range.
    return prob[idx].cpu(), label[idx].cpu(), idx.cpu()

In [7]:
def remove_collisions(lbl_model0, lbl_model1, idx_model0, idx_model1):
    """Drop the instances both models selected but labeled differently.

    Args:
        lbl_model0: Labels predicted by model0, aligned with ``idx_model0``.
        lbl_model1: Labels predicted by model1, aligned with ``idx_model1``.
        idx_model0: Instance indexes selected by model0.
        idx_model1: Instance indexes selected by model1.

    Returns:
        ``(lbl_model0, lbl_model1, idx_model0, idx_model1)`` with every
        conflicting instance removed from both sides. The inputs are
        returned untouched when there is no conflict.
    """
    # instances selected by both models, together with the position each one
    # occupies in the two index lists
    shared, pos0, pos1 = np.intersect1d(
        idx_model0, idx_model1, return_indices=True)

    print(f"Number of predictions (model0): {len(idx_model0)}")
    print(f"Number of predictions (model1): {len(idx_model1)}")
    print(f"Found {len(shared)} instances in predict(model0) INTERSECT predict(model1)")

    # a collision is a shared instance on which the two labels disagree
    disagree = lbl_model0[pos0] != lbl_model1[pos1]
    collisions = shared[disagree]

    print(f"Found {len(collisions)} conflicting predictions")

    if len(collisions) == 0:
        return lbl_model0, lbl_model1, idx_model0, idx_model1

    print(f"Collisions: {collisions}")

    # TODO keep some kind of log of the discarded collisions

    # keep-masks: True everywhere except at the positions of the collisions
    keep0 = np.ones(len(idx_model0), dtype=bool)
    keep0[pos0[disagree]] = False
    keep1 = np.ones(len(idx_model1), dtype=bool)
    keep1[pos1[disagree]] = False

    return (lbl_model0[keep0], lbl_model1[keep1],
            idx_model0[keep0], idx_model1[keep1])


In [8]:
def cotrain(
        loader0, loader1, loader_unlbl,
        model0, model1, loss_fn, optimizer0, optimizer1,
        k, device):
    """Run one co-training round with two models on two views.

    Each model is trained for one pass on its own labeled loader. Both then
    score the unlabeled set, and each model's (up to) ``k`` highest-scoring
    predictions are appended, as pseudo-labels, to the *other* model's
    labeled dataset. Instances selected by both models with different labels
    are discarded beforehand. Every instance that was handed over is removed
    from the unlabeled dataset.

    The three datasets are modified in place through ``loader.dataset``;
    nothing is returned.

    NOTE(review): this relies on the datasets exposing tensor ``data`` /
    ``targets`` attributes, and on ``loader_unlbl`` iterating in dataset
    order so that prediction rows can be used as dataset indexes -- confirm
    against the callers. Only ``loader_unlbl.dataset.data`` is pruned; a
    ``targets`` attribute on the unlabeled dataset is left untouched.
    """
    print("training model0 ...")
    train(loader0, model0, loss_fn, optimizer0, device)
    print("------------------------------\ntraining model1 ...")
    train(loader1, model1, loss_fn, optimizer1, device)

    print("------------------------------\nmaking predictions with model0 ...")
    pred_model0 = predict(loader_unlbl, model0, device)
    print("making predictions with model1 ...")
    pred_model1 = predict(loader_unlbl, model1, device)

    # top-k predictions per model (labels, instance indexes in the dataset);
    # k is capped at the number of available predictions
    _, lbl_topk0, idx_topk0 = get_topk_pred(
        pred_model0, min(k, len(pred_model0)))
    _, lbl_topk1, idx_topk1 = get_topk_pred(
        pred_model1, min(k, len(pred_model1)))

    unlabeled = loader_unlbl.dataset
    print(f"Number of unlabeled instances: {len(unlabeled.data)}")

    # both models may be confident about the same instance; drop the ones
    # on which they disagree
    lbl_topk0, lbl_topk1, idx_topk0, idx_topk1 = remove_collisions(
        lbl_topk0, lbl_topk1, idx_topk0, idx_topk1)

    # cross-feed: model1's picks extend loader0's dataset and vice versa
    handovers = (
        (loader0.dataset, idx_topk1, lbl_topk1),
        (loader1.dataset, idx_topk0, lbl_topk0),
    )
    for labeled, picked_idx, picked_lbl in handovers:
        labeled.data = torch.cat((labeled.data, unlabeled.data[picked_idx]), 0)
        labeled.targets = torch.cat((labeled.targets, picked_lbl), 0)

    # remove everything that was handed over from the unlabeled dataset
    mask_unlbl = np.ones(len(unlabeled.data), dtype=bool)
    mask_unlbl[idx_topk0] = False
    mask_unlbl[idx_topk1] = False
    print(f"Number of unlabeled instances to remove: {(~mask_unlbl).sum()}")
    unlabeled.data = unlabeled.data[mask_unlbl]


In [9]:
def test(loader, model, loss_fn, device):
    """Evaluate ``model`` on ``loader`` and print accuracy and mean loss.

    Args:
        loader: Yields ``(inputs, targets)`` batches and exposes ``dataset``.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar.
        device: Device the batches are moved to before the forward pass.

    The loss is averaged over batches, the accuracy over instances. Nothing
    is returned; the metrics are only printed.
    """
    n_samples = len(loader.dataset)
    n_batches = len(loader)
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            scores = model(X)
            loss_sum += loss_fn(scores, y).item()
            # an instance counts as correct when its top-scoring output
            # equals the target
            hits = scores.argmax(1) == y
            n_correct += hits.type(torch.float).sum().item()
    avg_loss = loss_sum / n_batches
    accuracy = n_correct / n_samples
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f}")

In [10]:
def add_to_imagefolder(paths, labels, dataset):
    """
    Adds the paths with the labels to an image classification dataset

    The dataset's ``samples`` list is extended in place. When the dataset
    also keeps a ``targets`` list that was in step with ``samples``, the new
    labels are appended to it as well so the two stay aligned.

    :list paths: a list of absolute image paths to add to the dataset
    :list labels: a list of labels for each path
    :Dataset dataset: the dataset to add the samples to
    :return: the same dataset object
    :raises ValueError: if ``paths`` and ``labels`` differ in length
    """

    paths, labels = list(paths), list(labels)
    if len(paths) != len(labels):
        # zip() would silently drop the unmatched tail
        raise ValueError(
            f"got {len(paths)} paths but {len(labels)} labels")

    new_samples = list(zip(paths, labels))
    n_before = len(dataset.samples)

    dataset.samples += new_samples

    # Only touch a plain-list ``targets`` that matched ``samples`` before
    # the call; anything else (missing, tensor, already stale) is left alone.
    targets = getattr(dataset, "targets", None)
    if isinstance(targets, list) and len(targets) == n_before:
        targets.extend(labels)

    return dataset

In [11]:
# Seed every random number generator the notebook has imported, not only
# Python's `random`: model weights are initialized from torch's generator,
# so seeding `random` alone does not make runs repeatable.
torch.manual_seed(13)
np.random.seed(13)
random.seed(13)

In [12]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"using {device}")

using cuda


In [13]:
# with open('cotraining_samples_lists.pkl', 'rb') as fp:
#     dict = pickle.load(fp)
# Load the sample lists from the "_fixed" pickle (presumably the output of
# the commented-out path / label fix-up cells further down -- confirm).
# NOTE(review): pickle.load can execute arbitrary code; only load files from
# a trusted source.
# NOTE(review): the name `dict` shadows the builtin for the rest of the
# notebook; renaming it would have to be done in every cell that uses it.
with open('cotraining_samples_lists_fixed.pkl', 'rb') as fp:
    dict = pickle.load(fp)

In [14]:
# Inspect the top-level keys of the loaded pickle.
dict.keys()

dict_keys(['labeled', 'inferred', 'class_map'])

In [15]:
# Inspect the class-name -> class-index mapping stored in the pickle.
dict['class_map']

{'dry': 0, 'snow': 1, 'wet': 2}

In [16]:
# TODO replace these placeholder transforms.
# For now the images are simply center-cropped to 28 x 28, the input size
# the current CNN architecture works with, and converted to tensors.
trans = transforms.Compose([transforms.CenterCrop(28), transforms.ToTensor()])

In [17]:
# Build two placeholder ImageFolder datasets over the labeled directory.
# Both scan the same folder; their sample lists are replaced further down
# with the pickled 'labeled' / 'inferred' lists, so the scan result itself
# is not what gets trained on.
data_labeled = datasets.ImageFolder('/ourdisk/hpc/ai2es/jroth/data/labeled', transform=trans)
data_unlabeled = datasets.ImageFolder('/ourdisk/hpc/ai2es/jroth/data/labeled', transform=trans)

In [18]:
# Peek at the first (path, class index) sample found by the directory scan.
(data_labeled.samples[0])

('/ourdisk/hpc/ai2es/jroth/data/labeled/bronx_allsites/dry/NYSDOT_m4er5dez4ab_2022-01-26-00-01-07.jpg',
 0)

In [19]:
# First entry of the pickled labeled list, to compare its format with the
# ImageFolder sample shown above.
dict['labeled'][0]

('/ourdisk/hpc/ai2es/jroth/data/labeled/bronx_allsites/snow/NYSDOT_uyomtjhwsay_2022-01-29-06-51-02.jpg',
 1)

In [20]:
# # whoops, the paths are relative
# # let's go ahead and replace these with the absolute paths
# root0 = '/ourdisk/hpc/ai2es/jroth/'

# # dunno there might be a better way than this. anyways
# # also the labels are incorrect... let's re-index these
# for i, (img_path, label) in enumerate(dict['labeled']):
#     dict['labeled'][i] = (img_path.replace('./', root0), label - 1)

In [21]:
# Re-check the first labeled entry; the fix-up cell before this one is
# commented out, so the entry is unchanged.
dict['labeled'][0]

('/ourdisk/hpc/ai2es/jroth/data/labeled/bronx_allsites/snow/NYSDOT_uyomtjhwsay_2022-01-29-06-51-02.jpg',
 1)

In [22]:
# update samples
data_labeled.samples = dict['labeled']
# ImageFolder also keeps `imgs` (an alias of `samples`) and `targets` (the
# label of every sample); refresh both so they do not keep describing the
# directory scan that was just replaced
data_labeled.imgs = data_labeled.samples
data_labeled.targets = [label for _, label in data_labeled.samples]

# update the class_idx and stuff
data_labeled.class_to_idx = dict['class_map']
# order the class names by their index so that classes[i] is class i
data_labeled.classes = sorted(dict['class_map'], key=dict['class_map'].get)

In [23]:
# First entry of the pickled 'inferred' list, which is used as the
# unlabeled pool below.
dict['inferred'][0]

('/ourdisk/hpc/ai2es/datasets/DOT/Skyline_6464/20220129/I_87_at_Interchange_3_(Yonkers_Mile_Square_Road)__Northbound__Skyline_6464_2022-01-29-06:50:09.jpg',
 1)

In [24]:
# # same here
# root1 = '/ourdisk/hpc/ai2es/'

# for i, (img_path, label) in enumerate(dict['inferred']):
#     dict['inferred'][i] = (img_path.replace('../', root1), label - 1)

In [25]:
# Re-check the first 'inferred' entry; the fix-up cell before this one is
# commented out, so the entry is unchanged.
dict['inferred'][0]

('/ourdisk/hpc/ai2es/datasets/DOT/Skyline_6464/20220129/I_87_at_Interchange_3_(Yonkers_Mile_Square_Road)__Northbound__Skyline_6464_2022-01-29-06:50:09.jpg',
 1)

In [26]:
# update samples list
data_unlabeled.samples = dict['inferred']
# ImageFolder also keeps `imgs` (an alias of `samples`) and `targets` (the
# label of every sample); refresh both so they do not keep describing the
# directory scan that was just replaced
data_unlabeled.imgs = data_unlabeled.samples
data_unlabeled.targets = [label for _, label in data_unlabeled.samples]

# update class idx and stuff
data_unlabeled.class_to_idx = dict['class_map']
# order the class names by their index so that classes[i] is class i
data_unlabeled.classes = sorted(dict['class_map'], key=dict['class_map'].get)

In [27]:
# # forgot how to python but whatever this does the job
# for clazz, idx in dict['class_map'].items():
#     dict['class_map'][clazz] = idx - 1

# dict['class_map']

In [28]:
# # pickle this
# with open('cotraining_samples_lists_fixed.pkl', 'wb') as fp:
#     pickle.dump(dict, fp)

In [29]:
# Number of labeled samples now that the pickled list has been swapped in.
len(data_labeled)

4303

In [30]:
# Length of the pickled labeled list itself; it should equal the dataset
# length reported above.
len(dict['labeled'])

4303

In [31]:
# Same check as two cells earlier (dataset length).
len(data_labeled)

4303

In [32]:
# Instantiate the network with its default settings and move it to the
# selected device.
model0 = ConvNet().to(device)

In [33]:
batch_size = 64
# Shuffling is off for both loaders. For the unlabeled loader this keeps the
# prediction rows in dataset order; NOTE(review): confirm that the labeled
# loader is meant to stay unshuffled during training as well.
loader_labeled = DataLoader(data_labeled, batch_size=batch_size, shuffle=False)
loader_unlabeled = DataLoader(data_unlabeled, batch_size=batch_size, shuffle=False)

In [34]:
# Cross-entropy on the raw scores produced by ConvNet (which applies no
# softmax itself), optimized with RMSprop at a learning rate of 1e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer0 = torch.optim.RMSprop(model0.parameters(), lr=1e-3)