In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports all modules
# before each cell executes so edits to project .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import os, sys
import glob
import numpy as np
import torch
import pandas as pd
import random
import torch.nn as nn
import pickle
import torch.nn.functional as F

In [3]:
from mstcn_model import *
from utility.adaptive_data_loader import Breakfast, collate_fn_override
from utils import calculate_mof, dotdict

In [4]:
# SECURITY: the Weights & Biases API key must never be hard-coded in the
# notebook. A key that was committed here must be treated as leaked and
# rotated in the W&B account settings.
# wandb.login() obtains the key from the WANDB_API_KEY environment variable
# (export it in the shell before launching Jupyter), from a previous
# `wandb login`, or by prompting interactively.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdipika_singhania[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
seed = 42


def set_seed(seed_value=None):
    """Seed every random number generator this notebook relies on.

    Args:
        seed_value: Integer seed to apply. When omitted, the notebook-level
            ``seed`` is used, so existing ``set_seed()`` calls are unchanged.
    """
    if seed_value is None:
        seed_value = seed

    # Ensure deterministic behavior: turn off the cuDNN autotuner and request
    # deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # manual_seed_all covers every visible GPU, including the current one, so
    # a separate torch.cuda.manual_seed call is not needed.
    torch.cuda.manual_seed_all(seed_value)


set_seed()

# Device configuration: restrict CUDA to the device listed below, then use the
# first visible GPU when CUDA is available and fall back to the CPU otherwise.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# os.environ['CUDA_LAUNCH_BLOCKING']='6'
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
# Experiment configuration. The "{}" placeholders in the split-file templates
# are filled in below from `split` (and `data_per` for the semi-supervised list).
config = dotdict(
    epochs=500,
    num_class=48,
    batch_size=8,
    learning_rate=5e-4,
    weight_decay=0,
    dataset="Breakfast",
    architecture="unet-ensemble",
    features_file_name="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/",
    chunk_size=1,
    max_frames_per_video=1200,
    feature_size=2048,
    ground_truth_files_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/",
    label_id_csv="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv",
    gamma=0.1,
    step_size=500,
    split=1,
    output_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split1/",
    project_name="breakfast-split-1",
    train_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split{}.bundle",
    test_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split{}.bundle",
    all_files="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt",
    cutoff=8,
    data_per = 0.2,
    budget=40,
    semi_supervised_split="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/semi_supervised/train.split{}_amt{}.bundle")

# Resolve the split-specific file names.
config.train_split_file = config.train_split_file.format(config.split)
config.semi_supervised_split = config.semi_supervised_split.format(config.split, config.data_per)
config.test_split_file = config.test_split_file.format(config.split)

# makedirs also creates missing parent directories, and exist_ok avoids the
# race between checking for the directory and creating it.
os.makedirs(config.output_dir, exist_ok=True)

print(config)

{'epochs': 500, 'num_class': 48, 'batch_size': 8, 'learning_rate': 0.0005, 'weight_decay': 0, 'dataset': 'Breakfast', 'architecture': 'unet-ensemble', 'features_file_name': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/', 'chunk_size': 1, 'max_frames_per_video': 1200, 'feature_size': 2048, 'ground_truth_files_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/', 'label_id_csv': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv', 'gamma': 0.1, 'step_size': 500, 'split': 1, 'output_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split1/', 'project_name': 'breakfast-split-1', 'train_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split1.bundle', 'test_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split1.bundle', 'all_files': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt', 'cutoff': 8, 'data_per': 0.2, 'budget': 40, 'semi_supe

In [7]:
# Read the label mapping CSV and build lookups in both directions:
# label id -> label name, and label name -> label id.
df = pd.read_csv(config.label_id_csv)
label_id_to_label_name = dict(zip(df.label_id, df.label_name))
label_name_to_label_id_dict = dict(zip(df.label_name, df.label_id))

In [8]:
# Build the train dataset first, then the test dataset, each from its own
# split file listed in the config.
traindataset, testdataset = (
    Breakfast(config, fold=fold_name, fold_file_name=split_path)
    for fold_name, split_path in (
        ('train', config.train_split_file),
        ('test', config.test_split_file),
    )
)

Number of videos logged in train fold is 1460
Number of videos not found in train fold is 0
Number of videos logged in test fold is 252
Number of videos not found in test fold is 0


In [9]:
def _init_fn(worker_id):
    """Seed NumPy and Python's `random` inside a DataLoader worker process.

    Seeding every worker with the same constant makes all workers draw the
    same random stream. Inside a worker, torch.initial_seed() already differs
    per worker and is derived from the main-process seed, so it yields
    distinct but reproducible streams (the recipe from the PyTorch
    reproducibility notes).

    Args:
        worker_id: Index of the worker, supplied by the DataLoader. It is not
            needed here because torch.initial_seed() is already worker-specific.
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# Settings shared by both loaders; only the dataset and shuffling differ.
_loader_kwargs = dict(
    batch_size=config.batch_size,
    pin_memory=True,
    num_workers=4,
    collate_fn=lambda batch: collate_fn_override(batch, config.max_frames_per_video),
    worker_init_fn=_init_fn,
)

# Training batches are shuffled; test batches keep the dataset order.
trainloader = torch.utils.data.DataLoader(dataset=traindataset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(dataset=testdataset, shuffle=False, **_loader_kwargs)

In [10]:
# Re-seed immediately before building the model so weight initialisation does
# not depend on how much randomness earlier cells consumed.
set_seed()

# Input width, class count and optimizer settings come from `config` so the
# model cannot silently drift away from the values declared there.
model = MultiStageModel(num_stages=4, num_layers=10, num_f_maps=64,
                        dim=config.feature_size, num_classes=config.num_class).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate,
                             weight_decay=config.weight_decay)

# Targets equal to -100 are ignored by the cross-entropy term.
ce_criterion = nn.CrossEntropyLoss(ignore_index=-100)
# Element-wise (unreduced) MSE; the training loop clamps and masks it itself.
mse_criterion = nn.MSELoss(reduction='none')

In [13]:
pseudo_labels_dir = "/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/length_segmentation_output/"


def get_single_random(output_p, video_ids, labels_dir=None, label_map=None):
    """Load per-frame pseudo labels for a batch into a padded target tensor.

    For every video id, ``<labels_dir>/<video_id>.txt`` is read (one label
    name per line), each name is mapped to its integer id, and the ids are
    written into the leading positions of that video's row. Remaining
    positions keep the value -100.

    Args:
        output_p: Tensor whose ``shape[0]`` gives the batch size and
            ``shape[2]`` the number of frame positions; the result is created
            on its device.
        video_ids: Iterable of video id strings, one per batch row.
        labels_dir: Directory holding the label files. Defaults to the
            module-level ``pseudo_labels_dir``.
        label_map: Mapping from label name to integer id. Defaults to the
            module-level ``label_name_to_label_id_dict``.

    Returns:
        ``torch.long`` tensor of shape ``(output_p.shape[0], output_p.shape[2])``.

    Raises:
        IndexError: If a label file has more lines than there are frame positions.
        KeyError: If a label name is missing from ``label_map``.
    """
    if labels_dir is None:
        labels_dir = pseudo_labels_dir
    if label_map is None:
        label_map = label_name_to_label_id_dict

    batch_size, num_frames = output_p.shape[0], output_p.shape[2]
    boundary_target_tensor = torch.full((batch_size, num_frames), -100, dtype=torch.long,
                                        device=output_p.device)
    for row, cur_vidid in enumerate(video_ids):
        # The context manager closes the file promptly. splitlines() keeps the
        # final label even when the file does not end with a newline.
        with open(os.path.join(labels_dir, cur_vidid + ".txt")) as label_file:
            label_names = label_file.read().splitlines()
        label_ids = [label_map[name] for name in label_names]
        if len(label_ids) > num_frames:
            raise IndexError(
                f"{cur_vidid}: {len(label_ids)} pseudo labels do not fit into {num_frames} frames")
        # Build the ids directly on the target's device rather than on a global one.
        boundary_target_tensor[row, :len(label_ids)] = torch.tensor(
            label_ids, dtype=torch.long, device=output_p.device)

    return boundary_target_tensor

In [14]:
# Best validation accuracy seen so far and the epoch that produced it.
best_val_acc = 0
best_epoch = -1


def _valid_frame_masks(labels, lengths):
    """Return (boolean mask, float mask with an extra channel axis).

    A position is True when its frame index is smaller than the corresponding
    entry of `lengths`.
    """
    frame_index = torch.arange(labels.shape[1], device=labels.device)
    bool_mask = frame_index[None, :] < lengths[:, None]
    float_mask = bool_mask.unsqueeze(1).to(torch.float32).to(device)
    return bool_mask, float_mask


def _masked_hit_counts(stage_outputs, labels, bool_mask):
    """Count correct and total masked positions for the last stage's argmax."""
    hard_pred = torch.argmax(stage_outputs[-1], dim=1)
    hits = float(torch.sum((hard_pred == labels) * bool_mask).item())
    seen = float(torch.sum(bool_mask).item())
    return hits, seen


for epoch in range(100):
    print("Starting Training")
    model.train()
    for step, batch in enumerate(trainloader):
        feats = batch[0].to(device)
        lengths = batch[1].to(device)
        labels = batch[2].to(device)
        bool_mask, float_mask = _valid_frame_masks(labels, lengths)
        optimizer.zero_grad()

        middle_pred, stage_outputs = model(feats, float_mask)
        # Training targets are the pseudo labels looked up by the ids in batch[4].
        pseudo_targets = get_single_random(stage_outputs[-1], batch[4])

        loss = 0
        for stage_out in stage_outputs:
            loss += ce_criterion(stage_out, pseudo_targets)
            # Smoothing term: clamped squared difference between the log-probabilities
            # of neighbouring positions; the earlier position is detached.
            log_p_next = F.log_softmax(stage_out[:, :, 1:], dim=1)
            log_p_prev = F.log_softmax(stage_out.detach()[:, :, :-1], dim=1)
            smooth_term = torch.clamp(mse_criterion(log_p_next, log_p_prev), min=0, max=16)
            loss += 0.15 * torch.mean(smooth_term * float_mask[:, :, 1:])

        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            with torch.no_grad():
                # The reported accuracy compares against batch[2], not the pseudo targets.
                hits, seen = _masked_hit_counts(stage_outputs, labels, bool_mask)
                print(f"Training:: Epoch {epoch}, Iteration {step}, Current loss {loss.item()}" +
                      f" Accuracy {hits * 100.0 / seen}")

    # Validation pass over the test loader.
    model.eval()

    print("Calculating Validation Data Accuracy")
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for batch in testloader:
            feats = batch[0].to(device)
            lengths = batch[1].to(device)
            labels = batch[2].to(device)
            bool_mask, float_mask = _valid_frame_masks(labels, lengths)

            middle_pred, stage_outputs = model(feats, float_mask)

            hits, seen = _masked_hit_counts(stage_outputs, labels, bool_mask)
            correct += hits
            total += seen

    val_acc = correct * 100.0 / total
    # Keep a separate copy of the weights whenever validation accuracy improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), config.output_dir + "ms-tcn-best-model.wt")
    # Always overwrite the "last" checkpoint.
    torch.save(model.state_dict(), config.output_dir + "ms-tcn-last-model.wt")
    print(f"Validation:: Epoch {epoch}, Probability Accuracy {val_acc}")

Starting Training
Training:: Epoch 0, Iteration 0, Current loss 15.688101768493652 Accuracy 1.7120043307619435
Training:: Epoch 0, Iteration 10, Current loss 14.867568969726562 Accuracy 3.1712991306477285
Training:: Epoch 0, Iteration 20, Current loss 12.627345085144043 Accuracy 8.095451921226505
Training:: Epoch 0, Iteration 30, Current loss 13.397523880004883 Accuracy 10.806481142438667
Training:: Epoch 0, Iteration 40, Current loss 12.8862886428833 Accuracy 25.437440076701822
Training:: Epoch 0, Iteration 50, Current loss 13.115045547485352 Accuracy 14.34950178605001
Training:: Epoch 0, Iteration 60, Current loss 13.697102546691895 Accuracy 5.139224639627541
Training:: Epoch 0, Iteration 70, Current loss 10.90279769897461 Accuracy 31.25250501002004
Training:: Epoch 0, Iteration 80, Current loss 12.055412292480469 Accuracy 14.958582743817644
Training:: Epoch 0, Iteration 90, Current loss 12.215043067932129 Accuracy 9.596265237313046
Training:: Epoch 0, Iteration 100, Current loss 11.

Training:: Epoch 4, Iteration 70, Current loss 6.280013084411621 Accuracy 36.54118313048401
Training:: Epoch 4, Iteration 80, Current loss 7.169579982757568 Accuracy 36.03206765923707
Training:: Epoch 4, Iteration 90, Current loss 6.607783317565918 Accuracy 34.43472446648465
Training:: Epoch 4, Iteration 100, Current loss 5.889108657836914 Accuracy 51.48262813897489
Training:: Epoch 4, Iteration 110, Current loss 6.916120529174805 Accuracy 31.214701927386823
Training:: Epoch 4, Iteration 120, Current loss 5.950628757476807 Accuracy 50.024175027196904
Training:: Epoch 4, Iteration 130, Current loss 6.824543476104736 Accuracy 44.49109915584132
Training:: Epoch 4, Iteration 140, Current loss 5.868003845214844 Accuracy 43.218733080671356
Training:: Epoch 4, Iteration 150, Current loss 5.4972429275512695 Accuracy 52.68215081132676
Training:: Epoch 4, Iteration 160, Current loss 7.453354835510254 Accuracy 36.96352333158565
Training:: Epoch 4, Iteration 170, Current loss 5.6124420166015625 Ac

Training:: Epoch 8, Iteration 150, Current loss 5.290956497192383 Accuracy 57.1322505800464
Training:: Epoch 8, Iteration 160, Current loss 4.486000061035156 Accuracy 50.41947005565246
Training:: Epoch 8, Iteration 170, Current loss 4.723947525024414 Accuracy 58.09263061829797
Training:: Epoch 8, Iteration 180, Current loss 7.25172233581543 Accuracy 34.866059099696216
Calculating Validation Data Accuracy
Validation:: Epoch 8, Probability Accuracy 51.56285242826786
Starting Training
Training:: Epoch 9, Iteration 0, Current loss 5.134907245635986 Accuracy 44.44444444444444
Training:: Epoch 9, Iteration 10, Current loss 5.608646869659424 Accuracy 55.45439288411513
Training:: Epoch 9, Iteration 20, Current loss 6.688746929168701 Accuracy 49.37062416207359
Training:: Epoch 9, Iteration 30, Current loss 6.824410438537598 Accuracy 37.94405879951
Training:: Epoch 9, Iteration 40, Current loss 6.757987976074219 Accuracy 43.35959307735252
Training:: Epoch 9, Iteration 50, Current loss 6.34982967

Training:: Epoch 13, Iteration 20, Current loss 3.784616231918335 Accuracy 57.015669515669515
Training:: Epoch 13, Iteration 30, Current loss 5.065441131591797 Accuracy 47.51049402647724
Training:: Epoch 13, Iteration 40, Current loss 6.399113655090332 Accuracy 57.605335936229054
Training:: Epoch 13, Iteration 50, Current loss 3.8770017623901367 Accuracy 51.5831894070236
Training:: Epoch 13, Iteration 60, Current loss 4.6334686279296875 Accuracy 64.90309062336301
Training:: Epoch 13, Iteration 70, Current loss 3.9832348823547363 Accuracy 54.59950734732014
Training:: Epoch 13, Iteration 80, Current loss 4.956575393676758 Accuracy 65.93398487921814
Training:: Epoch 13, Iteration 90, Current loss 4.285998344421387 Accuracy 61.68397812233429
Training:: Epoch 13, Iteration 100, Current loss 5.584327697753906 Accuracy 57.11390191558699
Training:: Epoch 13, Iteration 110, Current loss 3.568197011947632 Accuracy 67.48174284073193
Training:: Epoch 13, Iteration 120, Current loss 4.3211708068847

KeyboardInterrupt: 

In [16]:
# Display the path the training loop writes the best-validation checkpoint to.
config.output_dir + "ms-tcn-best-model.wt"

'/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split1/ms-tcn-best-model.wt'

In [15]:
# Display the epoch index at which validation accuracy last improved (-1 if it never did).
best_epoch

10

In [24]:
# Write the current model weights to a fixed location outside config.output_dir.
final_weights_path = "/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-speed/final-em-maximized.wt"
torch.save(model.state_dict(), final_weights_path)

In [34]:
# NOTE(review): this checkpoint lives in a "split3" results directory while
# config.split is 1 — confirm that evaluating it on the split-1 test set is
# intended, since the splits may overlap.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint_path = "/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-split3/ms-tcn-initial-15-epochs.wt"
# map_location remaps the stored tensors onto the device used in this session,
# so a checkpoint saved on a different GPU index (or loaded on a CPU-only
# machine) still loads.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [25]:
# Standalone validation pass over the test loader.
# This cell may run after an interrupted training epoch or a checkpoint load,
# either of which can leave the model in train() mode. Switch to eval mode
# explicitly so any train-only layers (e.g. dropout) are disabled.
model.eval()

print("Calculating Validation Data Accuracy")
correct = 0.0
total = 0.0
with torch.no_grad():
    for i, item in enumerate(testloader):
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        # True where the frame index is below the per-row value in item_1.
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

        middle_pred, predictions = model(item_0, src_mask_mse)

        # Score the argmax of the last stage, counting only masked-in positions.
        pred = torch.argmax(predictions[-1], dim=1)
        correct += float(torch.sum((pred == item_2) * src_mask).item())
        total += float(torch.sum(src_mask).item())

# NOTE: `epoch` is whatever value an earlier cell left in the kernel.
print(f"Validation:: Epoch {epoch}, Probability Accuracy {correct * 100.0 / total}")

Calculating Validation Data Accuracy
Validation:: Epoch 15, Probability Accuracy 57.770101729343025


In [9]:
import pickle
# pickle.dump(video_id_boundary_frames, open("dump_dir/video_id_boundary_frames_dict.pkl", "wb"))
# pickle.dump(loaded_vidid_selected_frames, open("dump_dir/loaded_vidid_selected_frames_dict.pkl", "wb"))
# NOTE(review): `boundary_dict` is not defined in any visible cell of this
# notebook; on a fresh kernel this cell raises NameError.
# The context manager flushes and closes the file even if pickling fails.
with open("dump_dir/chunk_1_video_id_boundary_frames_dict.pkl", "wb") as dump_file:
    pickle.dump(boundary_dict, dump_file)