In [1]:
import os
import random
import time

import numpy as np
import pandas as pd
import scipy as sc
from scipy.io import wavfile
from scipy import signal
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from sklearn.metrics import roc_curve, roc_auc_score

import tensorboardX
from tqdm import tqdm

import matplotlib.pyplot as plt

In [2]:
import warnings
warnings.filterwarnings('ignore') # scipy throws future warnings on fft (known bug)

In [3]:
# Reproducibility: seed every RNG the notebook uses (python `random` in
# RandomHorizontalFlip, numpy in the dataset/augmentations, torch for weight
# init and DataLoader shuffling). 13 matches the `seed=13` defaults of the
# dataset/subnet constructors, which do not apply it themselves.
SEED = 13
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True  # False by default
torch.backends.cudnn.benchmark = False  # autotuning may select non-deterministic kernels

In [4]:
class AudioFrameDataset(Dataset):
    '''Face-frame / voice-spectrogram pairs (train split only).

    Each item pairs one randomly chosen face frame of a face-track with the
    magnitude spectrogram of a random 3-second crop of the matching audio.

    Args:
        path_to_data: root directory containing the `video/` and `audio/` folders.
        path_to_split: space-separated split file with `<path> <phase>` rows;
            rows with phase 1 or 3 are used.
        transform: optional callable applied to the `(frame, spectrogram)` tuple
            (both are cast to float32 first).
        seed: accepted for backward compatibility; not used.
        max_samples: keep only the first `max_samples` rows of the split
            (300 by default, the subset used for the overfit test); pass None
            to use the whole split.
    '''

    def __init__(self, path_to_data, path_to_split, transform=None, seed=13,
                 max_samples=300):
        self.path_to_data = path_to_data
        voice_set_labels = pd.read_table(path_to_split, sep=' ', names=['path', 'phase'])
        # map split-file names onto the on-disk layout and drop the extension
        voice_set_labels = voice_set_labels.replace({'_000': '/0', '.wav$': ''}, regex=True)
        mask = (voice_set_labels.phase == 1) | (voice_set_labels.phase == 3)
        dataset = voice_set_labels[mask].reset_index(drop=True)
        self.dataset = dataset['path'][:max_samples]
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item_path = self.dataset[idx]
        # frame first, audio second: keeps the numpy RNG draw order fixed
        frame = self._load_frame(item_path)
        spectrogram = self._load_spectrogram(item_path)

        if self.transform:
            frame = frame.astype(np.float32)
            spectrogram = spectrogram.astype(np.float32)
            frame, spectrogram = self.transform((frame, spectrogram))

        return frame, spectrogram

    def _load_frame(self, item_path):
        '''Read one random frame of the face-track as a 224x224 RGB array.'''
        video_path = os.path.join(self.path_to_data, 'video', item_path + '.txt')
        frames = pd.read_table(video_path, skiprows=6, usecols=['FRAME '])
        earliest = frames['FRAME '].iloc[0]
        latest = frames['FRAME '].iloc[-1]
        frame_list = np.arange(earliest, latest + 1)
        # only every 25th frame is used, and at most 20 per face-track
        # (see the asterisk on the project page)
        frames_sec = frame_list[frame_list % 25 == 0][:20]
        if len(frames_sec) == 0:
            raise ValueError('no frame index divisible by 25 in {} (frames {}-{})'
                             .format(video_path, earliest, latest))
        selected_frame = np.random.choice(frames_sec)
        selected_frame_filename = '{0:07d}.jpg'.format(selected_frame)
        selected_frame_path = os.path.join(self.path_to_data, 'video',
                                           item_path[:-5] + selected_frame_filename)
        image = cv2.imread(selected_frame_path)
        if image is None:  # cv2.imread returns None instead of raising
            raise IOError('could not read image {}'.format(selected_frame_path))
        frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return cv2.resize(frame, (224, 224), interpolation=cv2.INTER_CUBIC)

    def _load_spectrogram(self, item_path):
        '''Magnitude spectrogram of a random 3 s crop of the pre-emphasised audio.'''
        audio_path = os.path.join(self.path_to_data, 'audio', item_path + '.wav')
        sample_rate, samples = wavfile.read(audio_path)

        segment_len = 3  # seconds
        window = 'hamming'
        window_width = int(sample_rate * 0.025)            # 25 ms window
        overlap = int(sample_rate * (0.025 - 0.010))       # i.e. 10 ms hop
        FFT_len = 2 ** (window_width - 1).bit_length()     # next power of two >= window
        pre_emphasis = 0.97

        # preemphasis filter
        samples = np.append(samples[0], samples[1:] - pre_emphasis * samples[:-1])

        segment_samples = segment_len * sample_rate
        start = self._random_segment_start(len(samples), segment_samples)
        end = start + segment_samples
        # Note: this yields 512x298 rather than the paper's size; the voice
        # subnet averages over time, so the difference does not matter.
        _, _, spectrogram = signal.spectrogram(samples[start:end], sample_rate,
                                               window=window, nfft=FFT_len,
                                               nperseg=window_width, noverlap=overlap,
                                               mode='magnitude', return_onesided=False)
        spectrogram *= sample_rate / 10
        return spectrogram

    @staticmethod
    def _random_segment_start(n_samples, segment_samples):
        '''Uniform start index such that `segment_samples` samples fit in the signal.

        Valid starts are 0..n_samples - segment_samples inclusive (`randint`'s
        upper bound is exclusive, hence the +1).
        '''
        upper_bound = n_samples - segment_samples
        if upper_bound < 0:
            raise ValueError('audio has {} samples, shorter than the {}-sample segment'
                             .format(n_samples, segment_samples))
        return np.random.randint(0, upper_bound + 1)

In [5]:
class Normalize(object):
    """Normalizes the face (per-channel mean) and the voice spectrogram
    (per-frequency-bin mean-variance).

    Takes a `(frame, spectrogram)` tuple with `frame` of shape (H, W, C) and
    `spectrogram` of shape (Freq, Time); any number of frequency bins works.
    Frequency bins that are constant over time (std 0) are mapped to 0
    instead of NaN.
    """

    def __call__(self, sample):
        frame, spectrogram = sample

        ## FACE (H, W, C)
        # mean normalization for every image (not batch)
        frame = frame - frame.mean(axis=(0, 1))

        ## VOICE (Freq, Time)
        # mean-variance normalization of each frequency bin over time, per
        # spectrogram (not batch-wise); keepdims gives (Freq, 1) for broadcasting
        mu = spectrogram.mean(axis=1, keepdims=True)
        sigma = spectrogram.std(axis=1, keepdims=True)
        sigma[sigma == 0] = 1  # avoid 0/0 for constant bins
        spectrogram = (spectrogram - mu) / sigma

        return frame, spectrogram

class RandomHorizontalFlip(object):
    '''Mirror the (H, W, C) image left-right with probability `p`.

    The spectrogram is passed through untouched.
    '''

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        image, spec = sample
        do_flip = random.random() < self.p
        if not do_flip:
            return image, spec
        # flipCode=1 flips around the vertical axis (horizontal mirror)
        return cv2.flip(image, 1), spec

class ColorJittering(object):
    '''Random brightness and saturation jitter for an (H, W, C) float image.

    The paper does not specify augmentation parameters. The authors used
    MatConvNet, so this mimics the brightness/saturation jitter described for
    `vl_imreadjpeg` (http://www.vlfeat.org/matconvnet/mfiles/vl_imreadjpeg/)
    and in Section 3.5 of the manual
    (http://www.vlfeat.org/matconvnet/matconvnet-manual.pdf). The default
    parameter values are a guess.

    Args:
        brightness: per-channel scale of the Gaussian additive offset.
        saturation: the blend factor is drawn from U(1 - saturation, 1 + saturation)
            and mixes the image with its channel-average (grey) image.
    Pixel values are clipped to [0, 255] after each step.
    '''

    def __init__(self, brightness=[255/25.5, 255/25.5, 255/25.5], saturation=0.5):
        self.B = np.array(brightness, dtype=np.float32)  # brightness scales
        self.S = saturation  # saturation range

    def __call__(self, sample):
        image, spec = sample

        # brightness: keep the noise float32 (normal() returns float64)
        noise = np.float32(np.random.normal(size=3))
        brightened = np.clip(image + self.B * noise, 0, 255)

        # saturation: blend with the grey (channel-mean) image
        factor = np.random.uniform(1 - self.S, 1 + self.S)
        grey_sum = brightened.sum(axis=2, keepdims=True)
        saturated = factor * brightened + (1 - factor) / 3 * grey_sum

        return np.clip(saturated, 0, 255), spec
    
class ToTensor(object):
    """Convert the `(frame, spectrogram)` ndarrays into torch tensors.

    frame:       (H, W, C)    -> (C, H, W)        numpy HWC to torch CHW layout
    spectrogram: (Freq, Time) -> (1, Freq, Time)  add a single input channel
    """

    def __call__(self, sample):
        image, spec = sample
        n_freq, n_time = spec.shape

        chw_image = np.transpose(image, (2, 0, 1))
        spec_3d = spec.reshape(1, n_freq, n_time)

        return torch.from_numpy(chw_image), torch.from_numpy(spec_3d)

In [6]:
## TRY TO ADD DROPOUT

class FaceSubnet(nn.Module):
    '''Face branch: (B, 3, 224, 224) image -> L2-normalised 256-d embedding.

    Five conv+BN+ReLU stages (max-pooling after stages 1, 2 and 5) followed by
    three fully connected layers. `seed` is accepted but not used.
    '''

    def __init__(self, seed=13):
        super(FaceSubnet, self).__init__()
        # creation order is kept fixed so seeded weight init is reproducible
        self.conv1 = nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # batch norm only after the conv layers
        self.bn1 = nn.BatchNorm2d(96)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(256)

        # 224 -> conv1 112 -> pool 56 -> conv2 28 -> pool 14 -> pool 7
        self.fc6 = nn.Linear(256 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 1024)
        self.fc8 = nn.Linear(1024, 256)

        self.mpool = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError('expected a (B, C, H, W) tensor, got {} dims'.format(x.dim()))

        stages = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, False),
            (self.conv4, self.bn4, False),
            (self.conv5, self.bn5, True),
        )
        for conv, bn, pool_after in stages:
            x = self.relu(bn(conv(x)))
            if pool_after:
                x = self.mpool(x)

        x = x.view(x.shape[0], -1)
        for fc in (self.fc6, self.fc7):
            x = self.relu(fc(x))
        return F.normalize(self.fc8(x))

## TRY TO REMOVE DROPOUT

class VoiceSubnet(nn.Module):
    '''Voice branch: (B, 1, Freq, Time) spectrogram -> L2-normalised 256-d embedding.

    The layer sizes are chosen for 512 frequency bins. After conv1-conv5 and the
    three max-pools, the frequency axis is 9 high, which the (9, 1) `fc6`
    convolution collapses. The remaining time axis is averaged, so any
    sufficiently long spectrogram works. `seed` is accepted but not used.
    '''

    def __init__(self, seed=13):
        super(VoiceSubnet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=7, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        # only after conv layers
        self.bn1 = nn.BatchNorm2d(num_features=96)
        self.bn2 = nn.BatchNorm2d(num_features=256)
        self.bn3 = nn.BatchNorm2d(num_features=256)
        self.bn4 = nn.BatchNorm2d(num_features=256)
        self.bn5 = nn.BatchNorm2d(num_features=256)
        # NOTE: bn6 is not used in forward(); kept so existing state_dicts load.
        self.bn6 = nn.BatchNorm2d(num_features=4096)

        self.relu = nn.ReLU()

        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.mpool5 = nn.MaxPool2d(kernel_size=(5, 3), stride=(3, 2))

        # Conv2d with weights of size (H, 1) is identical to FC with H weights
        self.fc6 = nn.Conv2d(in_channels=256, out_channels=4096, kernel_size=(9, 1))
        self.fc7 = nn.Linear(in_features=4096, out_features=1024)
        self.fc8 = nn.Linear(in_features=1024, out_features=256)

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError('expected a (B, C, H, W) tensor, got {} dims'.format(x.dim()))
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.mpool1(x)
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.mpool2(x)
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.mpool5(x)

        x = self.relu(self.fc6(x))
        # Average over the whole time axis: equivalent to AvgPool2d((1, W)), but
        # without registering a new submodule inside forward() on every call.
        x = x.mean(dim=3, keepdim=True)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc7(x))
        x = self.fc8(x)

        return F.normalize(x)
    
class CurriculumMining(nn.Module):
    '''Pick one negative voice for every face in the batch.

    Row i is the positive pair (faces[i], voices[i]). For each face, all voices
    are ranked from farthest (easiest) to closest (hardest). A voice is eligible
    as a negative only if it is farther from the face than the face's own
    voice. `tau` in [0, 1] limits the search to the first
    round(tau * (B - 1)) + 1 ranked voices, and the hardest eligible voice
    among them is chosen, so larger tau gives harder negatives. If no voice is
    eligible, the farthest voice other than the positive is used.

    Args:
        random_negatives: if True, skip mining and pair the faces with a random
            non-zero cyclic shift of the voices (never a face's own voice);
            useful as a sanity-check baseline.

    Returns (faces, negative_voices); gradients flow into `negative_voices`.
    '''

    def __init__(self, random_negatives=False):
        super(CurriculumMining, self).__init__()
        self.random_negatives = random_negatives

    def forward(self, positive_pairs, tau):
        faces, voices = positive_pairs
        B = faces.size(0)
        if B < 2:
            raise ValueError('negative mining needs at least 2 pairs per batch, got {}'.format(B))

        if self.random_negatives:
            shift = int(torch.randint(1, B, (1,)).item())
            neg_samples_idx = (torch.arange(B, device=voices.device) + shift) % B
        else:
            if tau is None:
                raise ValueError('tau is required for curriculum mining')
            neg_samples_idx = self._mine(faces, voices, tau)

        # indexing happens outside no_grad so negatives stay in the graph
        negative_voices = voices[neg_samples_idx]

        return faces, negative_voices

    @staticmethod
    def _mine(faces, voices, tau):
        '''Index of the curriculum-mined negative voice for every face.'''
        B = faces.size(0)
        # only indices are needed, so keep the selection out of autograd
        with torch.no_grad():
            # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y ; clamp tiny negative
            # rounding errors so sqrt cannot produce NaN
            sq = ((faces ** 2).sum(dim=1).view(-1, 1) + (voices ** 2).sum(dim=1)
                  - 2 * faces.matmul(voices.t()))
            dists = sq.clamp(min=0).sqrt()

            sorted_dist, sorted_idx = torch.sort(dists, dim=1, descending=True)
            # how much farther each ranked voice is than the positive voice
            Dnj = sorted_dist - dists.diag().view(-1, 1)
            idx_threshold = int(round(tau * (B - 1)))

            # eligible = within the first idx_threshold + 1 ranks and farther
            # than the positive; eligibility is a prefix of each sorted row
            mask = torch.ones_like(sorted_dist)
            mask[:, idx_threshold + 1:] = 0
            mask[Dnj <= 0] = 0
            n_allowed = mask.sum(dim=1).long()

            # no eligible voice: take rank 0 unless it is the positive itself
            rows = torch.arange(B, device=sorted_idx.device)
            fallback = (sorted_idx[:, 0] == rows).long()
            col = torch.where(n_allowed > 0, n_allowed - 1, fallback)
            return torch.gather(sorted_idx, dim=1, index=col.view(B, 1)).view(B)
    
class LearnablePinsNet(nn.Module):
    '''Face/voice two-stream embedding network with curriculum negative mining.

    Training mode: returns `(positive_pairs, negative_pairs)`, each a
    `(face_embeddings, voice_embeddings)` tuple; negatives are mined with `tau`.
    Eval mode: returns `(face_embeddings, voice_embeddings)`.
    '''

    def __init__(self):
        super(LearnablePinsNet, self).__init__()
        self.face_subnet = FaceSubnet()
        self.voice_subnet = VoiceSubnet()
        self.curr_mining = CurriculumMining()

    def forward(self, frames, specs, tau=None):
        face_emb = self.face_subnet(frames)
        voice_emb = self.voice_subnet(specs)

        if not self.training:
            return face_emb, voice_emb

        positives = (face_emb, voice_emb)
        return positives, self.curr_mining(positives, tau)

In [7]:
class ContrastiveLoss(nn.Module):
    '''Contrastive loss over B positive and B negative face/voice pairs.

    positive term: ||f - v||^2
    negative term: max(0, margin - ||f - v||)^2
    The loss is the mean of all 2B terms.

    Args:
        eps: lower bound on the squared negative distance before the sqrt.
            d sqrt(s)/ds is infinite at s = 0, which would give NaN gradients
            for a negative pair with identical embeddings.
    '''

    def __init__(self, eps=1e-12):
        super(ContrastiveLoss, self).__init__()
        self.eps = eps

    def forward(self, positive_pairs, negative_pairs, margin):
        ## POSITIVE PART (squared distance, no sqrt needed)
        faces, voices = positive_pairs
        pos_part = ((faces - voices) ** 2).sum(dim=1)

        ## NEGATIVE PART
        faces, voices = negative_pairs
        dists_neg = ((faces - voices) ** 2).sum(dim=1).clamp(min=self.eps).sqrt()
        neg_part = (margin - dists_neg).clamp(min=0) ** 2

        loss4pair = torch.cat([pos_part, neg_part])

        ## CALCULATE LOSS: average over B positive + B negative terms
        B = faces.size(0)
        return loss4pair.sum() / (B + B)

In [8]:
class TauScheduler(object):
    '''Schedule for the curriculum-mining difficulty tau.

    The paper says it "found that it was effective to increase tau by 10
    percent every two epochs, starting from 30% up until 80%, and keeping it
    constant thereafter". Here the increase is multiplicative (tau += 10% of
    tau, truncated to a whole percent) on every second call to `step()` after
    the first. tau is capped at `highest`.

    Args:
        lowest: initial tau, as a fraction in [0, 1].
        highest: maximum tau, as a fraction in [0, 1].
        random_tau: if True, `get_tau()` ignores the schedule and returns a
            fresh U(0, 1) sample on every call (sanity-check mode).
    '''

    def __init__(self, lowest, highest, random_tau=False):
        # tau is tracked as an integer percentage; round() avoids 0.29 -> 28
        self.current = int(round(lowest * 100))
        self.highest = int(round(highest * 100))
        self.random_tau = random_tau
        self.epoch_num = 0

    def step(self):
        '''Advance the schedule by one epoch.'''
        if self.epoch_num % 2 == 0 and self.epoch_num > 0:
            self.current = int(self.current + self.current * 0.1)

        # cap at the configured maximum
        self.current = min(self.current, self.highest)

        self.epoch_num += 1

    def get_tau(self):
        '''Current tau as a fraction (or a U(0, 1) sample if `random_tau`).'''
        if self.random_tau:
            return np.random.uniform()
        return self.current / 100.0

In [9]:
# Machine-specific settings can be overridden through environment variables;
# the defaults are the original locations.
LOG_PATH = os.environ.get('PINS_LOG_PATH', '/home/nvme/logs/LearnablePINs/_overfit_test/')
DATA_PATH = os.environ.get('PINS_DATA_PATH', '/home/nvme/data/voxceleb1/')
SPLIT_PATH = os.path.join(DATA_PATH, 'Splits/filtered_voice_set_labels.txt')
FACE_SUBNET_SNAPSHOT_PATH = os.path.join(LOG_PATH, 'face_subnet_snapshot.txt')
VOICE_SUBNET_SNAPSHOT_PATH = os.path.join(LOG_PATH, 'voice_subnet_snapshot.txt')
# GPU ids, e.g. PINS_DEVICES=0,1
DEVICES = [int(d) for d in os.environ.get('PINS_DEVICES', '1').split(',')]
B = 76 * len(DEVICES)  # batch size: 76 pairs per GPU
# https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/5
NUM_WORKERS = 4 * len(DEVICES)
MARGIN = 0.6  # contrastive-loss margin for negative pairs

In [None]:
# Scratch check: the largest loss a single negative pair can contribute is
# (MARGIN - 0) ** 2 = 0.36, reached when the negative distance is 0.
MARGIN ** 2

In [10]:
TBoard = tensorboardX.SummaryWriter(log_dir=LOG_PATH)

# ColorJittering clips to [0, 255], so it must see raw pixel values. Run it
# before the mean subtraction in Normalize; otherwise every pixel below its
# channel mean would be clipped to 0.
transform = Compose([
    RandomHorizontalFlip(),
    ColorJittering(),
    Normalize(),
    ToTensor(),
])

test_transform = Compose([
    Normalize(),
    ToTensor(),
])


def _seed_worker(worker_id):
    '''Give each DataLoader worker its own numpy / python RNG stream.

    The dataset and augmentations draw from np.random. In PyTorch < 1.9,
    forked workers inherit an identical numpy state, so every worker (and
    every epoch) would repeat the same crops and jitter. torch.initial_seed()
    already differs per worker and per epoch.
    '''
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


train = AudioFrameDataset(DATA_PATH, SPLIT_PATH, transform=transform)
trainloader = torch.utils.data.DataLoader(train, batch_size=B, num_workers=NUM_WORKERS,
                                          shuffle=True, pin_memory=True,
                                          worker_init_fn=_seed_worker)

net = LearnablePinsNet()

NUM_EPOCHS = 50
criterion = ContrastiveLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
# lr shrinks by 1e-6 over the NUM_EPOCHS - 1 decays: 1e-2 -> 1e-8
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1,
                                         gamma=10 ** (-6 / (NUM_EPOCHS - 1)))
tau_scheduler = TauScheduler(lowest=0.3, highest=0.8)
eval_results = {}

# derive the device from DEVICES so the two cannot disagree
device = torch.device('cuda:{}'.format(DEVICES[0]))
torch.cuda.set_device(DEVICES[0])
net.to(device);
# multi-GPU alternative: net = nn.DataParallel(net, DEVICES)

for epoch_num in range(NUM_EPOCHS):
    net.train()
    tau_scheduler.step()

    for iter_num, (frames, specs) in enumerate(tqdm(trainloader, desc='epoch {}'.format(epoch_num))):
        # transfer inputs to the device; non_blocking pairs with pin_memory=True
        # (`async=True` is a SyntaxError on Python >= 3.7)
        frames = frames.to(device, non_blocking=True)
        specs = specs.to(device, non_blocking=True)

        # zero the parameter gradients
        optimizer.zero_grad()

        # draw tau once so the logged value is the one actually used
        tau = tau_scheduler.get_tau()

        # forward + backward + optimize
        positive_pairs, negative_pairs = net(frames, specs, tau=tau)
        loss = criterion(positive_pairs, negative_pairs, margin=MARGIN)

        loss.backward()
        optimizer.step()

        # Tensorboard
        step_num = epoch_num * len(trainloader) + iter_num
        TBoard.add_scalar('Train/Loss', loss.item(), step_num)
        TBoard.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], step_num)
        TBoard.add_scalar('Train/tau', tau, step_num)

    # since PyTorch 1.1 the scheduler is stepped after the epoch's optimizer steps
    lr_scheduler.step()

4it [00:02,  1.33it/s]
4it [00:02,  1.37it/s]
4it [00:02,  1.45it/s]
4it [00:02,  1.45it/s]
4it [00:02,  1.46it/s]
4it [00:02,  1.46it/s]
4it [00:02,  1.38it/s]
4it [00:02,  1.44it/s]
4it [00:02,  1.41it/s]
4it [00:02,  1.39it/s]
4it [00:02,  1.39it/s]
4it [00:02,  1.40it/s]
4it [00:02,  1.67it/s]
4it [00:02,  1.20it/s]
4it [00:02,  1.19it/s]
4it [00:02,  1.19it/s]
4it [00:02,  1.16it/s]
4it [00:03,  1.09it/s]
4it [00:02,  1.24it/s]
4it [00:02,  1.13it/s]
4it [00:02,  1.22it/s]
4it [00:02,  1.20it/s]
4it [00:02,  1.17it/s]
4it [00:02,  1.15it/s]
4it [00:02,  1.18it/s]
4it [00:02,  1.13it/s]
4it [00:02,  1.21it/s]
4it [00:02,  1.16it/s]
4it [00:02,  1.11it/s]
4it [00:02,  1.15it/s]
4it [00:02,  1.18it/s]
4it [00:02,  1.38it/s]
4it [00:02,  1.41it/s]
4it [00:02,  1.37it/s]
4it [00:02,  1.45it/s]
4it [00:02,  1.40it/s]
4it [00:02,  1.40it/s]
4it [00:02,  1.51it/s]
4it [00:02,  1.45it/s]
4it [00:02,  1.46it/s]
4it [00:02,  1.43it/s]
4it [00:02,  1.39it/s]
4it [00:02,  1.51it/s]
4it [00:02,