In [1]:
#%matplotlib inline
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import time
import random
import copy
from copy import deepcopy
import threading

from train_utils import *

import gym3
from procgen import ProcgenGym3Env
import matplotlib.pyplot as plt

In [2]:
# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Echo the selected device so the saved output records where this run executed.
device

'cuda'

In [4]:
def ltoi(l):
    """Concatenate the decimal digits of each element of `l` into a single int.

    Example: [1, 2, 3] -> 123. An empty sequence raises ValueError (int('')).

    NOTE(review): because the result is parsed as an int, a leading 0 is lost
    (ltoi([0, 1]) == 1). If the procgen env decodes theme ids from these digits,
    theme 0 cannot be selected when it is listed first -- confirm against the env.
    """
    digits = "".join(map(str, l))
    return int(digits)

In [5]:
# Level range used by the training envs.
train_num_levels = 100_000  # earlier runs: 500, 1500
train_start_level = 0

# Every available colour theme for the scenery and for the road.
all_color_themes = list(range(8))
all_road_themes = list(range(8))

# Out-of-distribution settings: every theme (earlier runs used [0, 1] / [2, 3]),
# encoded as a digit-int for the env, plus heavy background noise.
color_themes_outdist = ltoi(all_color_themes)
color_themes_road_outdist = ltoi(all_road_themes)
outdist_backnoise = 100

# In-distribution settings: currently the same themes, but without background noise.
# NOTE(review): a disjoint split, e.g.
#   ltoi([n for n in all_color_themes if n not in <outdist theme list>])
# must be built from the theme *lists*; color_themes_outdist is already an int here.
color_themes_indist = color_themes_outdist
color_themes_road_indist = color_themes_road_outdist
indist_backnoise = 0


In [6]:
bs = 64  # number of parallel envs per batch


def _make_train_env(color_theme, color_theme_road, background_noise_level):
    """Build `bs` parallel procgen "testgame" envs over the training level range."""
    return ProcgenGym3Env(num=bs, env_name="testgame",
                          num_levels=train_num_levels, start_level=train_start_level,
                          color_theme=color_theme, color_theme_road=color_theme_road,
                          background_noise_level=background_noise_level)


# In-distribution envs (used with is_src_domain=True during training).
env = _make_train_env(color_themes_indist, color_themes_road_indist, indist_backnoise)
# The preview cell reads the in-distribution env as `env_indist`; define that name here
# so the notebook runs top-to-bottom on a fresh kernel.
env_indist = env

# Out-of-distribution envs: same level range, out-dist themes and background noise.
env_outdist = _make_train_env(color_themes_outdist, color_themes_road_outdist, outdist_backnoise)

building procgen...done


In [7]:
%%time

# Roll the in-distribution envs (`env`) forward with all-zero actions, then show the
# last RGB frame of env 14 as a rendering sanity check.
s = np.array([[.0, .0] for _ in range(bs)], dtype=np.float32)

seq_len = 200

for i in range(seq_len):
    env.act(s)
    rew, obs, first = env.observe()
    img = obs['rgb']
    info = env.get_info()
plt.imshow(img[14]);  # trailing ';' suppresses the AxesImage repr

CPU times: user 2.32 s, sys: 89 ms, total: 2.41 s
Wall time: 572 ms


In [None]:
%%time

# Roll the out-of-distribution envs forward with all-zero actions, then show the last
# RGB frame of env 14 to eyeball the out-dist themes and background noise.
seq_len = 200
s = np.zeros((bs, 2), dtype=np.float32)

for i in range(seq_len):
    env_outdist.act(s)
    rew, obs, first = env_outdist.observe()
    info = env_outdist.get_info()
    img = obs['rgb']

plt.imshow(img[14])

In [9]:
# One chunk producer per domain; each chunk spans 200 env steps (see the queued shapes
# printed below: (200, bs, 3, 64, 64)).
dataloader, dataloader_outdist = (
    DataLoader(env=source_env, bs=bs, seq_len=200)
    for source_env in (env, env_outdist)
)

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.36 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.35 seconds


In [10]:
# Pull one in-distribution chunk to inspect shapes: camera frames, aux features, targets.
front, aux, target = dataloader.get_chunk()

In [11]:
# Expected layout: (seq_len, bs, C, H, W), (seq_len, bs, n_aux), (seq_len, bs, n_targets).
front.shape, aux.shape, target.shape

(torch.Size([200, 64, 3, 64, 64]),
 torch.Size([200, 64, 5]),
 torch.Size([200, 64, 2]))

In [12]:
# Build the model on the training device. use_rnn=False presumably disables VizCNN's
# recurrent part -- confirm in train_utils.
m = VizCNN(use_rnn=False).to(device);

# Alternative backbone kept for reference:
#m = EfficientNet.from_pretrained('efficientnet-b4', in_channels=3, num_classes=2).to(device)

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.36 seconds


In [13]:
# Display the module tree.
m

VizCNN(
  (pooler): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (act): ReLU()
  (conv_1a): Conv2d(3, 16, kernel_size=(7, 7), stride=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn1_): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_2a): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (bn2a): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2a_): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_2b): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (bn2b): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2b_): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_3a): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3_)

In [14]:
#m.load_state_dict(torch.load("m.torch"))

In [15]:
# Total number of model parameters, in thousands.
sum(param.numel() for param in m.parameters()) / 1000

3193.314

In [16]:
# Smoke-test a forward pass on the first 6 timesteps of the inspected chunk.
# Run in eval mode: torch.no_grad() alone does not stop BatchNorm layers in train mode
# from updating their running statistics with this throwaway batch.
was_training = m.training
m.eval()
try:
    with torch.no_grad():
        hidden = get_hidden(bs)
        out, hidden = m(front[:6].to(device), aux[:6].to(device), hidden)
finally:
    m.train(was_training)

In [17]:
# Free the inspection/smoke-test tensors before training (requires the previous cells to have run).
del front, aux, target, out, hidden

In [18]:
# Function for implementing data augmentation

from torchvision import transforms

def deg_to_rad(x):
    """Convert degrees to radians using the exact pi/180 factor.

    Works on Python numbers and on tensors/arrays (plain multiplication).
    """
    return x * (np.pi / 180.0)
# Shared augmentation primitives: a 48x48 centre crop that is then resized back up to
# 64 (zooming in; presumably to keep rotated corners out of frame). color_jitter is
# defined but its only use in get_rotated_view is commented out.
crop = transforms.CenterCrop(size=48)
resize = transforms.Resize(size=64)
color_jitter = transforms.ColorJitter(hue=.5, saturation=.5, contrast=.5, brightness=.5)

def get_rotated_view(front, aux, rotation):
    """Rotate, centre-crop and resize a batch of frame sequences; shift aux to match.

    Args:
        front: tensor of shape (SEQ_LEN, BS, C, H, W). It is not modified.
        aux: rank-3 tensor indexed as aux[:, :, k]. A copy is returned with channel 0
            reduced by deg_to_rad(rotation).
            NOTE(review): channel 0 is presumably the heading angle -- confirm against
            aux_properties in train_utils.
        rotation: angle in degrees passed to torchvision's rotate (counter-clockwise).

    Returns:
        (ff, aux_out): frames with the same shape as `front`, and the adjusted aux copy.

    The crop + resize is applied even when rotation == 0, so this also produces the
    "base" (zoomed) view used for un-augmented inputs.
    """
    SEQ_LEN, BS, C, H, W = front.shape
    # torchvision ops return new tensors, so `front` itself is never written to.
    ff = front.reshape(SEQ_LEN * BS, C, H, W)
    ff = transforms.functional.rotate(ff, rotation)
    ff = crop(ff)
    # Resize back to the input's own (H, W) so the reshape below is valid for any frame
    # size; the shared `resize` transform is fixed at 64 (identical result for 64x64).
    ff = transforms.functional.resize(ff, [H, W])
    ##ff = color_jitter(ff)
    ff = ff.reshape(SEQ_LEN, BS, C, H, W)
    aux_out = aux.clone()
    aux_out[:, :, 0] -= deg_to_rad(rotation)
    return ff, aux_out


def get_output_shaping_loss(pred, target_mean=0., target_std=.55): #REVISIT THIS. should not be using these on rotation without more thought.
    """Penalise channel 0 of `pred` for drifting from a target mean / standard deviation.

    Args:
        pred: rank-3 tensor indexed as pred[:, :, 0]; channel 0 is the steering
            prediction (named steer_pred below).
        target_mean: desired mean of that channel (default 0).
        target_std: desired (unbiased) standard deviation of that channel (default 0.55).

    Returns:
        Scalar float32 tensor: (target_std - std)^2 + (target_mean - mean)^2.
    """
    # Upcast before taking moments: under autocast `pred` may be float16, where the
    # variance is computed with poor precision.
    steer_pred = pred[:, :, 0].float()
    steer_pred_mean, steer_pred_std = steer_pred.mean(), steer_pred.std()
    return (target_std - steer_pred_std) ** 2 + (target_mean - steer_pred_mean) ** 2

In [19]:
def testdrive(in_distribution=True, calibrate=False, use_training_wheels=False,
              seq_len=300, n_val=100):
    """Roll the global model `m` out closed-loop in fresh validation envs and score it.

    Args:
        in_distribution: if True, use the in-distribution themes / noise level and the
            training level range. Otherwise use the out-of-distribution themes / noise
            and levels starting right after the training range.
        calibrate: if True, drive with the envs' autopilot controls instead of `m`
            (gives the baseline score that validation scores are normalised by).
        use_training_wheels: if True, use the autopilot controls for the first
            TRAINING_WHEELS_WINDOW steps even when driving with `m`.
        seq_len: number of env steps to roll out (default 300).
        n_val: number of parallel validation envs (default 100).

    Returns:
        (reward, num_collisions, wp_infractions, successful_stops): each summed over all
        envs and steps, then divided by n_val * seq_len.

    Raises:
        ValueError: if seq_len or n_val is not positive.

    Background noise comes from the config values indist_backnoise / outdist_backnoise.
    The validation envs are always closed, and the model's previous train/eval mode is
    restored, even if the rollout raises.
    """
    if seq_len <= 0 or n_val <= 0:
        raise ValueError("seq_len and n_val must be positive")

    TRAINING_WHEELS_WINDOW = 10  # steps driven by the autopilot when use_training_wheels

    t1 = time.time()
    was_training = m.training
    m.eval()
    val_env = None
    try:
        val_env = ProcgenGym3Env(num=n_val,
                                 env_name="testgame",
                                 num_levels=train_num_levels,
                                 start_level=train_start_level if in_distribution else train_start_level+train_num_levels,
                                 color_theme=color_themes_indist if in_distribution else color_themes_outdist,
                                 color_theme_road=color_themes_road_indist if in_distribution else color_themes_road_outdist,
                                 background_noise_level=indist_backnoise if in_distribution else outdist_backnoise,
                                 )

        s = np.array([[.0, .0] for _ in range(n_val)], dtype=np.float32)  # first action: zeros
        reward = 0
        num_collisions = 0
        wp_infractions = 0
        successful_stops = 0

        hidden = get_hidden(n_val)

        with torch.no_grad():
            for i in range(seq_len):
                val_env.act(s)
                rew, obs, first = val_env.observe()
                reward += rew.sum()
                img = obs['rgb']
                info = val_env.get_info()
                num_collisions += np.array([e['collision'] for e in info]).sum()
                wp_infractions += np.array([e['waypoint_infraction'] for e in info]).sum()
                successful_stops += np.array([e['successful_stop'] for e in info]).sum()

                autopilot_control = np.array([[e["autopilot_"+c] for c in control_properties] for e in info])

                aux = np.array([[e[a] for a in aux_properties] for e in info])

                # (n_val, H, W, C) uint8 -> (1, n_val, C, H, W) float in [0, 1]
                front = torch.from_numpy(img.astype(np.float32)/255.).unsqueeze(0).permute(0,1,4,2,3).to(device)

                aux = torch.from_numpy(aux.astype(np.float32)).unsqueeze(0).to(device)

                # Same zoomed base view (rotation 0) that training uses.
                front, aux = get_rotated_view(front, aux, 0)

                if calibrate:
                    s = autopilot_control
                else:
                    out, hidden = m(front, aux, hidden, is_src_domain=in_distribution)
                    s = out.squeeze(0).squeeze(-1).cpu().numpy()
                    s = np.clip(s, -5., 5.)

                if use_training_wheels and i < TRAINING_WHEELS_WINDOW:
                    s = autopilot_control
    finally:
        if val_env is not None:
            val_env.close()
        m.train(was_training)

    n_steps = n_val * seq_len
    reward /= n_steps
    num_collisions /= n_steps
    wp_infractions /= n_steps
    successful_stops /= n_steps

    print(f"validation took {round(time.time()-t1)} seconds")
    return reward, num_collisions, wp_infractions, successful_stops

In [20]:
# Pre-training sanity check: drive the freshly built model out of distribution, with the
# autopilot handling the first steps (training wheels).
testdrive(in_distribution=False, calibrate=False, use_training_wheels=True)

validation took 13 seconds


(0.05196666666666667, 0.0, 0.0, 0.0)

In [21]:
# MSE on the control outputs. MSELoss has no parameters or buffers, so it needs no
# device placement (calling .cuda() on it only fails on CPU-only machines).
loss_fn = torch.nn.MSELoss()
# Mixed-precision loss scaling; explicitly disabled (a pass-through) without CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
opt = torch.optim.Adam(m.parameters(), lr=3e-4)

In [22]:
import wandb

In [23]:

#wandb.init(project="carlita")

In [24]:
#wandb.watch(m)

In [25]:

def _outdist_consistency_losses(front_outdist, aux_outdist):
    """Self-supervised consistency losses of the global model `m` on out-of-dist frames.

    Six views of the same frames are built: base (rotation 0), horizontal flip, and
    rotations by +r, -r, +2r, -2r degrees with r ~ U(2, 12). They are concatenated along
    dim 0 and pushed through `m` in a single forward pass.

    Returns:
        dict of tensors: "flip_loss", "rotation_loss", "rotation_loss_2",
        "output_shaping_loss".
    """
    rotation = random.uniform(2, 12)

    front_base, aux_base = get_rotated_view(front_outdist, aux_outdist, 0)
    front_flip = front_base.flip(-1)
    aux_flip = aux_base.clone()
    aux_flip[:, :, 0] *= -1  # mirrored view: negate aux channel 0 (NOTE(review): presumably heading)
    front_rot_pos, aux_rot_pos = get_rotated_view(front_outdist, aux_outdist, rotation)
    front_rot_neg, aux_rot_neg = get_rotated_view(front_outdist, aux_outdist, -rotation)
    front_rot_pos_2, aux_rot_pos_2 = get_rotated_view(front_outdist, aux_outdist, rotation * 2)
    front_rot_neg_2, aux_rot_neg_2 = get_rotated_view(front_outdist, aux_outdist, -rotation * 2)

    front_all = torch.cat([front_base, front_flip, front_rot_pos, front_rot_neg,
                           front_rot_pos_2, front_rot_neg_2], dim=0)
    aux_all = torch.cat([aux_base, aux_flip, aux_rot_pos, aux_rot_neg,
                         aux_rot_pos_2, aux_rot_neg_2], dim=0)

    with torch.cuda.amp.autocast():
        # NOTE(review): '' is passed as the hidden state; presumably ignored by VizCNN
        # when use_rnn=False -- confirm before enabling the RNN.
        pred_all, _ = m(front_all, aux_all, '', is_src_domain=False)

    (pred_base, pred_flip, pred_rot_pos, pred_rot_neg,
     pred_rot_pos_2, pred_rot_neg_2) = torch.split(pred_all, len(front_base), dim=0)

    output_shaping_loss = get_output_shaping_loss(torch.cat([pred_flip, pred_base], dim=0))

    # The base prediction is the anchor: no gradient flows through it in the terms below.
    steer_base = pred_base.detach()[:, :, 0]

    # Flip consistency: channel 0 on the mirrored view should be the negated base value.
    flip_loss = loss_fn(steer_base, pred_flip[:, :, 0] * -1)

    # Rotation consistency I: a +r / -r degree rotation should move channel 0 by
    # -deg_to_rad(r) * 1.27 / +deg_to_rad(r) * 1.27.
    # NOTE(review): the origin of the 1.27 gain is not documented here.
    diff_pos = pred_rot_pos[:, :, 0] - steer_base
    diff_neg = pred_rot_neg[:, :, 0] - steer_base
    rotation_loss = 0
    rotation_loss += loss_fn(diff_pos, -torch.ones_like(diff_pos) * deg_to_rad(rotation) * 1.27)
    rotation_loss += loss_fn(diff_neg, -torch.ones_like(diff_neg) * deg_to_rad(-rotation) * 1.27)

    # Rotation consistency II: a 2r rotation should move the prediction twice as far as
    # an r rotation (the detached r-rotation differences act as targets).
    rotation_loss_2 = 0
    rotation_loss_2 += loss_fn((diff_pos * 2).detach(), pred_rot_pos_2[:, :, 0] - steer_base)
    rotation_loss_2 += loss_fn((diff_neg * 2).detach(), pred_rot_neg_2[:, :, 0] - steer_base)

    return {
        "flip_loss": flip_loss,
        "rotation_loss": rotation_loss,
        "rotation_loss_2": rotation_loss_2,
        "output_shaping_loss": output_shaping_loss,
    }


def run_epoch(train=True):
    """Stream chunks from both dataloaders, train (or only evaluate) `m`, log and validate.

    Each in-distribution chunk (frames, aux, targets) is paired with an out-of-distribution
    chunk, of which only frames and aux are used. For every step:
      * supervised_loss is the MSE between `m`'s prediction and the target on in-dist frames;
      * flip / rotation / output-shaping consistency terms are computed on out-dist frames.
        They are currently only logged; the optimised loss is supervised_loss alone.
    Every `val_cadence` chunks (starting with the first), `m` is test-driven in both
    domains. The resulting scores are normalised by the autopilot's score.

    Args:
        train: if True, back-propagate and step `opt`. If False, run forward passes only,
            with the model in eval mode and without building autograd graphs.

    Returns:
        Mean over chunks of the per-chunk mean supervised loss. The loop only ends when a
        dataloader's get_chunk() returns a falsy value.

    Side effects: rebinds the global `bs` to each chunk's batch size, and saves
    m.state_dict() to 'm.torch' whenever training stats are printed.
    """
    # Autopilot ('perfect') scores, cached once per epoch to normalise validation scores.
    autopilot_score_baseline_in_dist = testdrive(in_distribution=True, calibrate=True)[0]
    autopilot_score_baseline_out_dist = testdrive(in_distribution=False, calibrate=True)[0]

    m.train(train)
    t1 = time.time()
    epoch_loss = []
    bptt = 1          # timesteps per optimisation step (earlier runs tried 32 and 64-88)
    frame_stride = 4  # advance bptt*4 frames per step, i.e. use every 4th step of a chunk

    val_cadence = 8   # test-drive every 8 chunks
    log_cadence = 1   # print training stats every chunk (once log_counter > 1)

    global dataloader, dataloader_outdist, bs
    log_counter = 0

    logger = Logger()

    while True:
        chunk = dataloader.get_chunk()
        if not chunk:
            break
        front_container, aux_container, target_container = chunk
        chunk_len, bs, _, _, _ = front_container.shape
        len_ix = 0
        chunk_loss = []

        chunk_outdist = dataloader_outdist.get_chunk()
        if not chunk_outdist:
            # No out-of-distribution chunk to pair with: end the epoch instead of crashing
            # on the unpack below.
            break
        front_container_outdist, aux_container_outdist, _ = chunk_outdist

        hidden = get_hidden(bs)  # reset the recurrent state at every chunk

        while len_ix < chunk_len:
            # Only build autograd graphs when they will be back-propagated.
            with torch.set_grad_enabled(train):
                # Supervised loss on in-distribution frames
                front = front_container[len_ix:len_ix + bptt].to(device).half()
                aux = aux_container[len_ix:len_ix + bptt].to(device).half()
                front, aux = get_rotated_view(front, aux, 0)
                target = target_container[len_ix:len_ix + bptt].to(device).half()

                with torch.cuda.amp.autocast():
                    pred, hidden = m(front, aux, hidden, is_src_domain=True)

                supervised_loss = loss_fn(target, pred)
                chunk_loss.append(supervised_loss.item())

                # Unsupervised consistency losses on out-of-distribution frames
                front_outdist = front_container_outdist[len_ix:len_ix + bptt].to(device).half()
                aux_outdist = aux_container_outdist[len_ix:len_ix + bptt].to(device).half()
                unsup = _outdist_consistency_losses(front_outdist, aux_outdist)

                #loss = flip_loss + rotation_loss + rotation_loss_2 + supervised_loss + output_shaping_loss
                loss = supervised_loss
                logger.log({
                    "supervised_loss": supervised_loss.item(),
                    "rotation_loss": unsup["rotation_loss"].item(),
                    "rotation_loss_2": unsup["rotation_loss_2"].item(),
                    "flip_loss": unsup["flip_loss"].item(),
                    "output_shaping_loss": unsup["output_shaping_loss"].item(),
                })

            len_ix += bptt * frame_stride

            if train:
                # Scale the loss and back-propagate scaled gradients.
                scaler.scale(loss).backward()
                # Unscale in place so gradient clipping sees the true magnitudes.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(m.parameters(), 5.)
                # Step (skipped by the scaler on inf/NaN gradients) and update the scale.
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()

            hidden = (hidden[0].detach(), hidden[1].detach())

        # Save and report at the end of each chunk
        t2 = time.time()
        chunk_loss = np.round(np.array(chunk_loss).mean(), 4)
        epoch_loss.append(chunk_loss)
        total_seconds = round(t2 - t1)  # includes any validation run in the previous chunk

        if train and log_counter % log_cadence == 0 and log_counter > 1:
            print(logger.finish())
            torch.save(m.state_dict(), 'm.torch')
            print(f'Done with chunk. Training took {total_seconds} seconds. Chunk loss was {chunk_loss}\n')

        if log_counter % val_cadence == 0:
            # Collision / waypoint / stop rates are also returned by testdrive but not logged.
            val_score_in_dist = testdrive(in_distribution=True, use_training_wheels=False)[0]
            val_score_in_dist /= autopilot_score_baseline_in_dist

            val_score_out_dist = testdrive(in_distribution=False, use_training_wheels=True)[0]
            val_score_out_dist /= autopilot_score_baseline_out_dist

            # testdrive() may leave the model in train mode; restore this epoch's mode.
            m.train(train)

            logger.log({
                "score_indist": np.round(val_score_in_dist, 2),
                "score_outdist": np.round(val_score_out_dist, 2)
            })
            stats = logger.finish()
            print(stats)

        t1 = t2
        log_counter += 1

    return np.array(epoch_loss).mean()

In [26]:
# Train until a dataloader stops yielding chunks (in practice, until interrupted);
# returns the mean per-chunk supervised loss.
run_epoch()

validation took 1 seconds
validation took 1 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.54 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.56 seconds
validation took 12 seconds
validation took 13 seconds
{'supervised_loss': 0.39925, 'rotation_loss': 0.08645, 'rotation_loss_2': 0.07104, 'flip_loss': 0.42531, 'output_shaping_loss': 0.17522, 'score_indist': 0.13, 'score_outdist': 0.16}
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.66 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.68 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.58 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.59 seconds
{'supervised_loss': 0.3261, 'rotation_loss': 0.06644, 'rotation_loss_2': 0.03788, 'flip_loss': 0.12354, 'output_shaping_loss': 0.21934}
Done with chunk. Training took 7 seconds. Chunk loss was 0.3552

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.6 se

{'supervised_loss': 0.11197, 'rotation_loss': 0.10542, 'rotation_loss_2': 0.25358, 'flip_loss': 1.10989, 'output_shaping_loss': 0.12249}
Done with chunk. Training took 7 seconds. Chunk loss was 0.112

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.6 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.61 seconds
{'supervised_loss': 0.09818, 'rotation_loss': 0.12387, 'rotation_loss_2': 0.32574, 'flip_loss': 1.4781, 'output_shaping_loss': 0.23263}
Done with chunk. Training took 8 seconds. Chunk loss was 0.0982

validation took 12 seconds
validation took 13 seconds
{'score_indist': 0.49, 'score_outdist': 0.44}
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.63 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.65 seconds
{'supervised_loss': 0.12686, 'rotation_loss': 0.14411, 'rotation_loss_2': 0.28569, 'flip_loss': 1.40507, 'output_shaping_loss': 0.21213}
Done with chunk. Training took 33 seconds. Chunk loss was 0.1269



{'supervised_loss': 0.11953, 'rotation_loss': 0.17619, 'rotation_loss_2': 0.37265, 'flip_loss': 3.06307, 'output_shaping_loss': 0.69199}
Done with chunk. Training took 7 seconds. Chunk loss was 0.1195

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.63 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.64 seconds
{'supervised_loss': 0.08655, 'rotation_loss': 0.20211, 'rotation_loss_2': 0.38849, 'flip_loss': 2.43616, 'output_shaping_loss': 0.47576}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0866

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.61 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.62 seconds
{'supervised_loss': 0.17967, 'rotation_loss': 0.16551, 'rotation_loss_2': 0.35079, 'flip_loss': 1.84211, 'output_shaping_loss': 0.3448}
Done with chunk. Training took 7 seconds. Chunk loss was 0.1797

validation took 13 seconds
validation took 13 seconds
{'score_indist': 0.42, 'score_outdist': 0.4}


Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.76 seconds
{'supervised_loss': 0.10173, 'rotation_loss': 0.26708, 'rotation_loss_2': 0.54333, 'flip_loss': 1.38595, 'output_shaping_loss': 0.25844}
Done with chunk. Training took 7 seconds. Chunk loss was 0.1017

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.59 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.62 seconds
{'supervised_loss': 0.06935, 'rotation_loss': 0.31152, 'rotation_loss_2': 0.61825, 'flip_loss': 1.56996, 'output_shaping_loss': 0.25658}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0694

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.58 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.6 seconds
{'supervised_loss': 0.09999, 'rotation_loss': 0.18275, 'rotation_loss_2': 0.33802, 'flip_loss': 1.34233, 'output_shaping_loss': 0.22879}
Done with chunk. Training took 7 seconds. Chunk loss was 0.1

Queueing chunk of size torch

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.53 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.72 seconds
{'supervised_loss': 0.11728, 'rotation_loss': 0.15786, 'rotation_loss_2': 0.35575, 'flip_loss': 1.28912, 'output_shaping_loss': 0.19189}
Done with chunk. Training took 7 seconds. Chunk loss was 0.1173

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.57 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.59 seconds
{'supervised_loss': 0.0749, 'rotation_loss': 0.18888, 'rotation_loss_2': 0.43675, 'flip_loss': 1.07853, 'output_shaping_loss': 0.11544}
Done with chunk. Training took 8 seconds. Chunk loss was 0.0749

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.59 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.6 seconds
{'supervised_loss': 0.06562, 'rotation_loss': 0.14139, 'rotation_loss_2': 0.35099, 'flip_loss': 0.71839, 'output_shaping_loss': 0.04785}
Done with chunk. 

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.51 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.71 seconds
{'supervised_loss': 0.06163, 'rotation_loss': 0.09753, 'rotation_loss_2': 0.20086, 'flip_loss': 0.51909, 'output_shaping_loss': 0.04934}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0616

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.54 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.65 seconds
{'supervised_loss': 0.06922, 'rotation_loss': 0.0812, 'rotation_loss_2': 0.19095, 'flip_loss': 0.54798, 'output_shaping_loss': 0.06186}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0692

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.55 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.61 seconds
{'supervised_loss': 0.06142, 'rotation_loss': 0.09251, 'rotation_loss_2': 0.22953, 'flip_loss': 0.58913, 'output_shaping_loss': 0.03709}
Done with chunk.

Done with chunk. Training took 33 seconds. Chunk loss was 0.0794

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.62 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.79 seconds
{'supervised_loss': 0.07157, 'rotation_loss': 0.07929, 'rotation_loss_2': 0.20335, 'flip_loss': 0.44071, 'output_shaping_loss': 0.01836}
Done with chunk. Training took 8 seconds. Chunk loss was 0.0716

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.57 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.64 seconds
{'supervised_loss': 0.06278, 'rotation_loss': 0.11018, 'rotation_loss_2': 0.25067, 'flip_loss': 0.74392, 'output_shaping_loss': 0.06539}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0628

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.57 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.64 seconds
{'supervised_loss': 0.06674, 'rotation_loss': 0.10788, 'rotation_loss_2': 0.2889, 'fli

validation took 13 seconds
validation took 13 seconds
{'score_indist': 0.87, 'score_outdist': 0.53}
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.53 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.78 seconds
{'supervised_loss': 0.09514, 'rotation_loss': 0.20572, 'rotation_loss_2': 0.47078, 'flip_loss': 1.055, 'output_shaping_loss': 0.17072}
Done with chunk. Training took 33 seconds. Chunk loss was 0.0951

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.62 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.71 seconds
{'supervised_loss': 0.0688, 'rotation_loss': 0.28231, 'rotation_loss_2': 0.58694, 'flip_loss': 1.16871, 'output_shaping_loss': 0.20896}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0688

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.58 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.64 seconds
{'supervised_loss': 0.06961, 'rotation_loss': 0.23987,

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.54 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.78 seconds
{'supervised_loss': 0.08068, 'rotation_loss': 0.21648, 'rotation_loss_2': 0.50309, 'flip_loss': 1.90574, 'output_shaping_loss': 0.33286}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0807

validation took 13 seconds
validation took 13 seconds
{'score_indist': 0.89, 'score_outdist': 0.25}
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.57 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.67 seconds
{'supervised_loss': 0.05894, 'rotation_loss': 0.17042, 'rotation_loss_2': 0.37858, 'flip_loss': 1.62229, 'output_shaping_loss': 0.2648}
Done with chunk. Training took 34 seconds. Chunk loss was 0.0589

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.59 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.66 seconds
{'supervised_loss': 0.05976, 'rotation_loss': 0.1616

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.56 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.81 seconds
{'supervised_loss': 0.0652, 'rotation_loss': 0.10202, 'rotation_loss_2': 0.24659, 'flip_loss': 0.72851, 'output_shaping_loss': 0.08959}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0652

Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.62 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.76 seconds
{'supervised_loss': 0.0661, 'rotation_loss': 0.1529, 'rotation_loss_2': 0.37246, 'flip_loss': 0.75836, 'output_shaping_loss': 0.12058}
Done with chunk. Training took 7 seconds. Chunk loss was 0.0661

validation took 13 seconds
validation took 14 seconds
{'score_indist': 0.94, 'score_outdist': 0.64}
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.57 seconds
Queueing chunk of size torch.Size([200, 64, 3, 64, 64]) took 0.62 seconds
{'supervised_loss': 0.04045, 'rotation_loss': 0.38383, 

KeyboardInterrupt: 