In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import os, sys
import glob
import numpy as np
import torch
import pandas as pd
import random
import torch.nn as nn
import pickle
import torch.nn.functional as F

In [3]:
from mstcn_model import *
from utility.adaptive_data_loader import Breakfast, collate_fn_override
from utils import calculate_mof, dotdict

In [4]:
# SECURITY: the W&B API key must never be hard-coded in the notebook -- anyone who
# can read the file (or its git history) can use it. A key that was committed
# here has to be revoked and regenerated in the W&B account settings.
# Provide the credential outside the notebook instead: export WANDB_API_KEY in
# the shell that launches Jupyter, or run `wandb login` once on the machine.
# wandb.login() picks the key up from there and prompts if none is configured.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdipika_singhania[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
seed = 42


def set_seed(seed_value=None):
    """Seed every RNG this notebook relies on and request deterministic cuDNN.

    Args:
        seed_value: Integer seed to apply. When omitted, the notebook-level
            ``seed`` is used, so existing ``set_seed()`` calls behave as before.
    """
    if seed_value is None:
        seed_value = seed
    # Turn off the cuDNN auto-tuner and ask for deterministic kernels so that
    # repeated runs with the same seed follow the same code path.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # manual_seed_all covers every CUDA device (including the current one), so
    # no separate torch.cuda.manual_seed call is needed.
    torch.cuda.manual_seed_all(seed_value)


set_seed()

# Device configuration
# Expose only physical GPU 7 to this process; the single visible GPU is then
# addressed as cuda:0 below. The variable is set before the first
# torch.cuda.is_available() query in this cell.
os.environ['CUDA_VISIBLE_DEVICES']='7'
# os.environ['CUDA_LAUNCH_BLOCKING']='6'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Run configuration: hyper-parameters, dataset locations and the split number
# consumed by the cells below. All paths are absolute paths on the training host.
# NOTE(review): not every field is read by the cells visible in this notebook
# (e.g. the training loop iterates a fixed range(100) rather than `epochs`) --
# confirm which entries are still in use.
config = dotdict(
    epochs=500,
    num_class=48,
    batch_size=8,
    learning_rate=5e-4,
    weight_decay=0,
    dataset="Breakfast",
    architecture="unet-ensemble",
    features_file_name="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/",
    chunk_size=1,
    # Passed as the second argument to collate_fn_override by both data loaders.
    max_frames_per_video=1200,
    feature_size=2048,
    ground_truth_files_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/",
    label_id_csv="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv",
    gamma=0.1,
    step_size=500,
    # Split number substituted into the `{}` placeholders of the split files below.
    split=4,
#     output_dir="/mnt/data/ar-datasets/dipika/breakfast/ms_tcn/data/breakfast/results/unsuper-finetune-split2-0.05-data-llr/",
    # Directory that receives the checkpoints written by the training loop.
    output_dir="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split4/",
    project_name="breakfast-split-4",
    train_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split{}.bundle",
    test_split_file="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split{}.bundle",
    all_files="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt",
    cutoff=8,
    # Substituted (together with `split`) into the semi_supervised_split template.
    data_per = 0.2,
    budget=40,
    semi_supervised_split="/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/semi_supervised/train.split{}_amt{}.bundle")

# Resolve the `{}` placeholders in the split-file templates with the chosen
# split number (and, for the semi-supervised bundle, the data_per value).
config.train_split_file = config.train_split_file.format(config.split)
config.semi_supervised_split = config.semi_supervised_split.format(config.split, config.data_per)
config.test_split_file = config.test_split_file.format(config.split)

# Create the results directory (including any missing parents). exist_ok=True
# makes this safe to re-run and avoids the check-then-create race of
# `os.path.exists` followed by `os.mkdir`.
os.makedirs(config.output_dir, exist_ok=True)

print(config)

{'epochs': 500, 'num_class': 48, 'batch_size': 8, 'learning_rate': 0.0005, 'weight_decay': 0, 'dataset': 'Breakfast', 'architecture': 'unet-ensemble', 'features_file_name': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/features/', 'chunk_size': 1, 'max_frames_per_video': 1200, 'feature_size': 2048, 'ground_truth_files_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/groundTruth/', 'label_id_csv': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/mapping.csv', 'gamma': 0.1, 'step_size': 500, 'split': 4, 'output_dir': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split4/', 'project_name': 'breakfast-split-4', 'train_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/train.split4.bundle', 'test_split_file': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/test.split4.bundle', 'all_files': '/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/splits/all_files.txt', 'cutoff': 8, 'data_per': 0.2, 'budget': 40, 'semi_supe

In [7]:
# Read the label mapping CSV and derive both lookup directions from its
# `label_id` and `label_name` columns (later rows win on duplicate keys).
df = pd.read_csv(config.label_id_csv)
label_id_to_label_name = dict(zip(df.label_id, df.label_name))
label_name_to_label_id_dict = dict(zip(df.label_name, df.label_id))

In [8]:
# Build the train and test datasets from the split bundle files resolved in
# `config`; the same config object is handed to both.
traindataset = Breakfast(config, fold='train', fold_file_name=config.train_split_file)
testdataset = Breakfast(config, fold='test', fold_file_name=config.test_split_file)

Number of videos logged in train fold is 1136
Number of videos not found in train fold is 0
Number of videos logged in test fold is 576
Number of videos not found in test fold is 0


In [9]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
def _collate_batch(batch):
    """Collate one batch with collate_fn_override, passing the configured
    ``max_frames_per_video`` as its second argument."""
    return collate_fn_override(batch, config.max_frames_per_video)


# Keyword arguments common to both loaders; only the dataset and the shuffle
# flag differ between training and testing.
_loader_kwargs = dict(
    batch_size=config.batch_size,
    pin_memory=True,
    num_workers=4,
    collate_fn=_collate_batch,
    worker_init_fn=_init_fn,
)

trainloader = torch.utils.data.DataLoader(dataset=traindataset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(dataset=testdataset, shuffle=False, **_loader_kwargs)

In [10]:
# Re-seed immediately before building the model so the weight initialisation
# does not depend on how much randomness earlier cells consumed.
set_seed()
# Input width, class count and optimiser settings come from `config` so they
# cannot drift from the values recorded for the run.
model = MultiStageModel(num_stages=4, num_layers=10, num_f_maps=64,
                        dim=config.feature_size, num_classes=config.num_class).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate,
                             weight_decay=config.weight_decay)

# Classification loss; target entries equal to -100 are skipped.
ce_criterion = nn.CrossEntropyLoss(ignore_index=-100)
# Unreduced squared error, so the training loop can clamp and mask it itself.
mse_criterion = nn.MSELoss(reduction='none')

In [11]:
pseudo_labels_dir = "/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast/length_segmentation_output/"


def get_single_random(output_p, video_ids):
    """Build the pseudo-label target tensor for one batch.

    For each video id, the file ``pseudo_labels_dir + <video_id> + ".txt"`` is
    read (one label name per line), every name is mapped to its class id via
    ``label_name_to_label_id_dict`` and the ids are written to the start of
    that video's row. All remaining positions keep the value -100.

    Args:
        output_p: Tensor whose dim 0 indexes the batch and dim 2 indexes time;
            only its shape and device are used.
        video_ids: Video id strings, one per batch row, in batch order.

    Returns:
        A ``torch.long`` tensor of shape ``(output_p.shape[0], output_p.shape[2])``
        on ``output_p.device``.

    Raises:
        OSError: If a pseudo-label file cannot be opened.
        KeyError: If a line holds a label name missing from the mapping.
        IndexError: If a file has more labels than ``output_p`` has time steps.
    """
    num_frames = output_p.shape[2]
    boundary_target_tensor = torch.full((output_p.shape[0], num_frames), -100,
                                        dtype=torch.long, device=output_p.device)
    for row, cur_vidid in enumerate(video_ids):
        label_path = pseudo_labels_dir + cur_vidid + ".txt"
        # `with` closes the handle; splitlines() keeps the final label even when
        # the file does not end with a newline.
        with open(label_path) as label_file:
            label_names = label_file.read().splitlines()
        label_ids = [label_name_to_label_id_dict[name] for name in label_names]
        if len(label_ids) > num_frames:
            raise IndexError(
                f"{label_path} has {len(label_ids)} labels but the batch only has "
                f"{num_frames} time steps")
        # Build the ids directly on the target's device and fill by slice, so
        # nothing depends on the notebook-level `device` variable.
        boundary_target_tensor[row, :len(label_ids)] = torch.tensor(
            label_ids, dtype=torch.long, device=boundary_target_tensor.device)

    return boundary_target_tensor

In [12]:
# Training loop: fit the model on the pseudo labels read from disk, measure
# accuracy on the test loader after every epoch, and checkpoint both the best
# and the most recent weights. `best_val_acc`, `best_epoch` and `epoch` remain
# in the notebook namespace and are read by later cells.
best_val_acc = 0
best_epoch = -1
for epoch in range(100):
    print("Starting Training")
    model.train()
    for i, item in enumerate(trainloader):
        # item[0] is the model input, item[1] is compared against time indices
        # (so it acts as a per-sample valid length), item[2] holds the labels the
        # accuracy is scored against and item[4] the video ids.
        # NOTE(review): exact batch layout comes from collate_fn_override -- confirm.
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        # Boolean mask, True where the time index is below that sample's item_1
        # value (assumes item_1 is 1-D with one entry per sample -- TODO confirm).
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        # Float copy of the mask with a singleton dim inserted at position 1.
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)
        optimizer.zero_grad()
        
        middle_pred, predictions = model(item_0, src_mask_mse)
        # Targets come from the pseudo-label files of this batch's videos, not
        # from item_2.
        psuedo_l = get_single_random(predictions[-1], item[4])
        loss = 0
        # Sum the loss over every entry of `predictions`: cross-entropy against
        # the pseudo labels plus 0.15 x a smoothing term, i.e. the squared
        # difference between log-softmax outputs of consecutive time steps (the
        # earlier step detached), clamped to [0, 16], masked and averaged.
        for p in predictions:
            loss += ce_criterion(p, psuedo_l)
            loss += 0.15 * torch.mean(torch.clamp(mse_criterion(F.log_softmax(p[:, :, 1:], dim=1), 
                                                                F.log_softmax(p.detach()[:, :, :-1], dim=1)), min=0,
                                        max=16) * src_mask_mse[:, :, 1:])
            

        loss.backward()
        optimizer.step()
        
        # Every 10th iteration, report the loss and the masked accuracy of the
        # last prediction against item_2 (not against the pseudo labels).
        if i % 10 == 0:
            with torch.no_grad():
                pred = torch.argmax(predictions[-1], dim=1)
                correct = float(torch.sum((pred == item_2) * src_mask).item())
                total = float(torch.sum(src_mask).item())
                print(f"Training:: Epoch {epoch}, Iteration {i}, Current loss {loss.item()}" +
                      f" Accuracy {correct * 100.0 / total}")
    # Switch to evaluation mode for the per-epoch pass over the test loader.
    model.eval()

    print("Calculating Validation Data Accuracy")
    # Counts of correctly predicted / all masked positions over the whole loader.
    correct = 0.0
    total = 0.0
    for i, item in enumerate(testloader):
        with torch.no_grad():
            item_0 = item[0].to(device)
            item_1 = item[1].to(device)
            item_2 = item[2].to(device)
            src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
            src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

            middle_pred, predictions = model(item_0, src_mask_mse)

            pred = torch.argmax(predictions[-1], dim=1)
            correct += float(torch.sum((pred == item_2) * src_mask).item())
            total += float(torch.sum(src_mask).item())
    val_acc = correct * 100.0 / total
    # Keep a separate copy of the weights whenever validation accuracy improves;
    # the "last" checkpoint is overwritten after every epoch regardless.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), config.output_dir + "ms-tcn-best-model.wt")
    torch.save(model.state_dict(), config.output_dir + "ms-tcn-last-model.wt")
    print(f"Validation:: Epoch {epoch}, Probability Accuracy {val_acc}")

Starting Training
Training:: Epoch 0, Iteration 0, Current loss 15.743518829345703 Accuracy 3.737044714243411
Training:: Epoch 0, Iteration 10, Current loss 13.267817497253418 Accuracy 16.09777263523286
Training:: Epoch 0, Iteration 20, Current loss 13.018302917480469 Accuracy 3.632347527303964
Training:: Epoch 0, Iteration 30, Current loss 12.549955368041992 Accuracy 18.77851684899126
Training:: Epoch 0, Iteration 40, Current loss 14.110186576843262 Accuracy 8.210645526613817
Training:: Epoch 0, Iteration 50, Current loss 11.841203689575195 Accuracy 27.4247491638796
Training:: Epoch 0, Iteration 60, Current loss 11.493738174438477 Accuracy 18.26611033843301
Training:: Epoch 0, Iteration 70, Current loss 12.444272994995117 Accuracy 8.700303084328757
Training:: Epoch 0, Iteration 80, Current loss 11.369422912597656 Accuracy 12.475806451612904
Training:: Epoch 0, Iteration 90, Current loss 11.29186725616455 Accuracy 16.064356435643564
Training:: Epoch 0, Iteration 100, Current loss 10.37

Training:: Epoch 5, Iteration 70, Current loss 5.100342273712158 Accuracy 44.30602891056528
Training:: Epoch 5, Iteration 80, Current loss 6.855910301208496 Accuracy 35.39423117530301
Training:: Epoch 5, Iteration 90, Current loss 5.335494518280029 Accuracy 55.97463059172795
Training:: Epoch 5, Iteration 100, Current loss 6.217436790466309 Accuracy 37.892152783556625
Training:: Epoch 5, Iteration 110, Current loss 5.597986698150635 Accuracy 53.34494355032429
Training:: Epoch 5, Iteration 120, Current loss 8.061092376708984 Accuracy 43.213135526107386
Training:: Epoch 5, Iteration 130, Current loss 5.146203517913818 Accuracy 59.9833232261977
Training:: Epoch 5, Iteration 140, Current loss 5.598455429077148 Accuracy 52.331765636675314
Calculating Validation Data Accuracy
Validation:: Epoch 5, Probability Accuracy 45.65454676174354
Starting Training
Training:: Epoch 6, Iteration 0, Current loss 4.9656662940979 Accuracy 55.60447003047748
Training:: Epoch 6, Iteration 10, Current loss 5.175

Calculating Validation Data Accuracy
Validation:: Epoch 10, Probability Accuracy 48.58683716887038
Starting Training
Training:: Epoch 11, Iteration 0, Current loss 5.692268371582031 Accuracy 51.03734439834025
Training:: Epoch 11, Iteration 10, Current loss 4.9820475578308105 Accuracy 56.28290941543953
Training:: Epoch 11, Iteration 20, Current loss 4.41371488571167 Accuracy 64.45260921981745
Training:: Epoch 11, Iteration 30, Current loss 6.036645412445068 Accuracy 50.77559462254395
Training:: Epoch 11, Iteration 40, Current loss 6.863383769989014 Accuracy 49.708682677277245
Training:: Epoch 11, Iteration 50, Current loss 4.880638122558594 Accuracy 51.4818209593645
Training:: Epoch 11, Iteration 60, Current loss 3.9843735694885254 Accuracy 62.959502823022866
Training:: Epoch 11, Iteration 70, Current loss 3.9651095867156982 Accuracy 64.33219418566547
Training:: Epoch 11, Iteration 80, Current loss 4.728310585021973 Accuracy 58.86927177581063
Training:: Epoch 11, Iteration 90, Current l

Training:: Epoch 16, Iteration 60, Current loss 3.8104162216186523 Accuracy 55.778513961192616
Training:: Epoch 16, Iteration 70, Current loss 4.625424861907959 Accuracy 52.0742016280526
Training:: Epoch 16, Iteration 80, Current loss 4.827786922454834 Accuracy 54.47529170781667
Training:: Epoch 16, Iteration 90, Current loss 3.998288631439209 Accuracy 63.82108933509679
Training:: Epoch 16, Iteration 100, Current loss 3.2853705883026123 Accuracy 66.35520439711439
Training:: Epoch 16, Iteration 110, Current loss 3.682086706161499 Accuracy 69.16115164884891
Training:: Epoch 16, Iteration 120, Current loss 3.2372567653656006 Accuracy 69.75381008206331
Training:: Epoch 16, Iteration 130, Current loss 3.3535311222076416 Accuracy 64.36553855908694
Training:: Epoch 16, Iteration 140, Current loss 3.6133103370666504 Accuracy 62.48002949489984
Calculating Validation Data Accuracy
Validation:: Epoch 16, Probability Accuracy 49.996241805501995
Starting Training
Training:: Epoch 17, Iteration 0, C

Training:: Epoch 21, Iteration 130, Current loss 3.8583261966705322 Accuracy 60.57400838439213
Training:: Epoch 21, Iteration 140, Current loss 3.4713962078094482 Accuracy 67.2345308500309
Calculating Validation Data Accuracy
Validation:: Epoch 21, Probability Accuracy 53.779109667383445
Starting Training
Training:: Epoch 22, Iteration 0, Current loss 2.7069709300994873 Accuracy 66.0601001669449
Training:: Epoch 22, Iteration 10, Current loss 3.332634449005127 Accuracy 56.062424969987994
Training:: Epoch 22, Iteration 20, Current loss 2.816572904586792 Accuracy 57.216383140101854
Training:: Epoch 22, Iteration 30, Current loss 3.3814637660980225 Accuracy 60.04691164972635
Training:: Epoch 22, Iteration 40, Current loss 3.6120786666870117 Accuracy 66.19381657573396
Training:: Epoch 22, Iteration 50, Current loss 3.470229148864746 Accuracy 64.167577911427
Training:: Epoch 22, Iteration 60, Current loss 3.2124738693237305 Accuracy 58.29692415239427
Training:: Epoch 22, Iteration 70, Curre

Training:: Epoch 27, Iteration 30, Current loss 3.367600202560425 Accuracy 66.96640235428758
Training:: Epoch 27, Iteration 40, Current loss 2.5105695724487305 Accuracy 60.49983412584319
Training:: Epoch 27, Iteration 50, Current loss 2.40655255317688 Accuracy 62.54094030639197
Training:: Epoch 27, Iteration 60, Current loss 2.310107469558716 Accuracy 53.37612155347436
Training:: Epoch 27, Iteration 70, Current loss 2.263051986694336 Accuracy 73.6498150431566
Training:: Epoch 27, Iteration 80, Current loss 3.279843330383301 Accuracy 63.34123124246581
Training:: Epoch 27, Iteration 90, Current loss 3.3800206184387207 Accuracy 65.81521932477142
Training:: Epoch 27, Iteration 100, Current loss 2.629836320877075 Accuracy 63.8642073656474
Training:: Epoch 27, Iteration 110, Current loss 2.577601909637451 Accuracy 56.26210610856522
Training:: Epoch 27, Iteration 120, Current loss 2.681715726852417 Accuracy 65.0373679237179
Training:: Epoch 27, Iteration 130, Current loss 3.23227596282959 Acc

Training:: Epoch 32, Iteration 100, Current loss 2.675433397293091 Accuracy 69.71625286876696
Training:: Epoch 32, Iteration 110, Current loss 2.290780782699585 Accuracy 68.29881272207297
Training:: Epoch 32, Iteration 120, Current loss 3.1344687938690186 Accuracy 64.54319761668322
Training:: Epoch 32, Iteration 130, Current loss 2.9279699325561523 Accuracy 64.48895094121612
Training:: Epoch 32, Iteration 140, Current loss 2.1218113899230957 Accuracy 62.04665054757503
Calculating Validation Data Accuracy
Validation:: Epoch 32, Probability Accuracy 53.32853482702501
Starting Training
Training:: Epoch 33, Iteration 0, Current loss 2.4621336460113525 Accuracy 71.84791548048334
Training:: Epoch 33, Iteration 10, Current loss 1.9354077577590942 Accuracy 57.27280313777852
Training:: Epoch 33, Iteration 20, Current loss 2.3795955181121826 Accuracy 57.66258246936852
Training:: Epoch 33, Iteration 30, Current loss 2.7427830696105957 Accuracy 66.16317792578496
Training:: Epoch 33, Iteration 40, 

KeyboardInterrupt: 

In [16]:
# Display the path the training loop writes the best-validation checkpoint to.
config.output_dir + "ms-tcn-best-model.wt"

'/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/mstcn-lenpsuedo-full-supervised-split1/ms-tcn-best-model.wt'

In [15]:
# Display the epoch index at which the training loop last improved validation accuracy.
best_epoch

10

In [24]:
# Snapshot the model's current weights to a fixed file.
# NOTE(review): this hard-coded directory is not config.output_dir -- confirm the
# target is intended and that the directory exists before running.
torch.save(model.state_dict(),
"/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-speed/final-em-maximized.wt")

In [34]:
# Restore weights from a previously saved checkpoint. map_location places the
# tensors on the device selected for this session, so the load does not depend
# on which device the checkpoint was written from. (The path has no
# placeholders, so it is a plain string rather than an f-string.)
model.load_state_dict(
    torch.load("/mnt/ssd/all_users/dipika/ms_tcn/data/breakfast//results/em-maximize-mstcn-split3/ms-tcn-initial-15-epochs.wt",
               map_location=device))

<All keys matched successfully>

In [25]:
# Stand-alone evaluation of the currently loaded weights on the test loader.
print("Calculating Validation Data Accuracy")
# Put the model in evaluation mode explicitly: this cell can run after an
# interrupted training epoch (which leaves the model in train mode) or right
# after load_state_dict, and any train-only layers (e.g. dropout, if the model
# has them) would otherwise distort the reported accuracy.
model.eval()
correct = 0.0
total = 0.0
with torch.no_grad():
    for i, item in enumerate(testloader):
        item_0 = item[0].to(device)
        item_1 = item[1].to(device)
        item_2 = item[2].to(device)
        # True where the time index is below the sample's item_1 value.
        src_mask = torch.arange(item_2.shape[1], device=item_2.device)[None, :] < item_1[:, None]
        src_mask_mse = src_mask.unsqueeze(1).to(torch.float32).to(device)

        middle_pred, predictions = model(item_0, src_mask_mse)

        # Score only the last entry of `predictions`, on masked positions.
        pred = torch.argmax(predictions[-1], dim=1)
        correct += float(torch.sum((pred == item_2) * src_mask).item())
        total += float(torch.sum(src_mask).item())

# NOTE(review): `epoch` is whatever value the training loop last left in the
# notebook namespace; it does not describe the checkpoint being evaluated.
print(f"Validation:: Epoch {epoch}, Probability Accuracy {correct * 100.0 / total}")

Calculating Validation Data Accuracy
Validation:: Epoch 15, Probability Accuracy 57.770101729343025


In [9]:
import pickle
# pickle.dump(video_id_boundary_frames, open("dump_dir/video_id_boundary_frames_dict.pkl", "wb"))
# pickle.dump(loaded_vidid_selected_frames, open("dump_dir/loaded_vidid_selected_frames_dict.pkl", "wb"))
# NOTE(review): `boundary_dict` is not defined in any cell visible in this
# notebook; this cell only works with leftover kernel state -- confirm its origin.
# Make sure the target directory exists, and use a context manager so the file
# handle is flushed and closed even if pickling fails.
os.makedirs("dump_dir", exist_ok=True)
with open("dump_dir/chunk_1_video_id_boundary_frames_dict.pkl", "wb") as dump_file:
    pickle.dump(boundary_dict, dump_file)