In [None]:
from train import *

In [None]:
# Argument string parsed by utils.prepare_parser() below. Adjacent string
# literals are joined at compile time into one space-separated string
# (note the intentional double space before --use_ema is preserved).
cmd = (
    "--loss_type Twin_AC --AC "
    "--AC_weight 1.0 "
    "--G_shared "
    "--n_domain 4 "
    "--shuffle --batch_size 200 "
    "--num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 1000 "
    "--num_D_steps 4 --num_G_steps 1 --G_lr 2e-4 --D_lr 2e-4 "
    "--source_dataset mnist,mnist_m,svhn,syn_digits --target_dataset mnist_m --num_workers 16 "
    "--G_ortho 0.0 "
    "--G_attn 0 --D_attn 0 --G_ch 64 --D_ch 64 "
    "--G_init N02 --D_init N02 "
    "--test_every 8000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 2019 "
    "--ema  --use_ema --ema_start 10000"
)

In [None]:
# Build the project's argument parser and parse the hard-coded `cmd` string;
# `config` is the resulting plain dict of options.
parser = utils.prepare_parser()
cli_tokens = cmd.split()
config = vars(parser.parse_args(cli_tokens))

In [None]:
# This run is configured with --source_dataset/--target_dataset (see `cmd`),
# so the usual lookups keyed on config['dataset']
# (utils.imsize_dict / utils.nclass_dict) are replaced by fixed values.
config['resolution'] = 32  # image side length; also fed to transforms.Resize
config['n_classes'] = 10   # digit classes
config['G_activation'] = utils.activation_dict[config['G_nl']]
config['D_activation'] = utils.activation_dict[config['D_nl']]

# NOTE(review): presumably skips weight initialisation because weights are
# loaded from a checkpoint afterwards — confirm against model.Generator.
config['skip_init'] = True
config = utils.update_config_roots(config)
# Hard-coded GPU index; adjust for the machine running this notebook.
device = 'cuda:5'

# Seed RNG
utils.seed_rng(config['seed'])

# Prepare root folders if necessary
utils.prepare_root(config)

# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# Import the model module named by config['model'] dynamically.
model = __import__(config['model'])
experiment_name = (config['experiment_name'] if config['experiment_name']
                   else utils.name_from_config(config))

# Build G and D on the chosen device.
G = model.Generator(**config).to(device)
D = model.Discriminator(**config).to(device)

# Optional EMA copy of G. Both names are bound in every case so that code
# passing `G_ema` (e.g. train_fns.save_and_sample in the training cell)
# does not hit a NameError when EMA is disabled.
if config['ema']:
    G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
else:
    G_ema, ema = None, None

In [None]:
# Combined G/D module plus the bookkeeping dict handed to utils.load_weights.
GD = model.G_D(G, D)
state_dict = {
    'itr': 0,
    # NOTE(review): hard-coded to match the `_epoch929` checkpoint noted below;
    # presumably overwritten when a checkpoint is loaded — confirm.
    'epoch': 929,
    'save_num': 0,
    'save_best_num': 0,
    'best_IS': 0,
    'best_FID': 999999,
    'config': config,
}

# Checkpoint suffix; an empty string is forwarded to load_weights as None.
# Earlier suffix, kept for reference:
#   "mnist,mnist_m,svhn,syn_digits_mnist_m" #_num_domain: 4_Twin_AC_AC_weight1.0_BigGAN_seed2019_Gch64_Dch64_bs200_nDs4_Glr2.0e-04_Dlr2.0e-04_Gnlrelu_Dnlrelu_GinitN02_DinitN02_Gshared_ema_epoch929"
config['load_weights'] = ''
G_batch_size = max(config['G_batch_size'], config['batch_size'])
weights_suffix = config['load_weights'] or None
utils.load_weights(G, D, state_dict,
                   config['weights_root'], experiment_name,
                   weights_suffix,
                   G_ema if config['ema'] else None)

# Noise and label buffers sized for G's batch.
z_, y_, yd_ = utils.prepare_z_y(G_batch_size, G.dim_z,
                                config['n_classes'], config['n_domain'],
                                device=device, fp16=config['G_fp16'])

In [None]:
# Generate one image per (domain, class) pair. Sample k * n_classes + c has
# class c and domain k, i.e. y = [0..C-1] repeated per domain and
# yd = [0]*C + [1]*C + ... .
n_classes, n_domains = config['n_classes'], config['n_domain']
n_samples = n_classes * n_domains
assert n_samples <= z_.shape[0], "z_ batch is smaller than the number of (domain, class) pairs"

z_.sample_()
y = torch.arange(n_classes).repeat(n_domains).to(device)
yd = torch.arange(n_domains).repeat_interleave(n_classes).to(device)

# Inference only: no autograd graph is needed (the plotting cell detaches).
with torch.no_grad():
    out = G(z_[:n_samples], G.shared(y), G.shared_d(yd))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

# Grid of generated samples: one row per domain, one column per class.
# Rows are displayed in this domain order (domains 1 and 3 swapped).
ROW_ORDER = [0, 3, 2, 1]
n_cols = config['n_classes']
# Domain whose images are colour-inverted for display
# (presumably mnist, the first --source_dataset entry — confirm).
INVERTED_DOMAINS = {0}

# 4 rows x 10 cols at 13x5 in keeps each cell ~1.3 x 1.25 in, as the previous
# 13x10 figure did, without its 4 unused subplot rows.
fig = figure(figsize=(13, 5))
for row, domain in enumerate(ROW_ORDER):
    for col in range(n_cols):
        ax = fig.add_subplot(len(ROW_ORDER), n_cols, row * n_cols + col + 1)
        img = out[domain * n_cols + col]
        # Min-max scale to [0, 1]; clamp avoids 0/0 on a constant image.
        lo, hi = img.min(), img.max()
        img = (img - lo) / (hi - lo).clamp(min=1e-8)
        img = img.detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC for imshow
        if domain in INVERTED_DOMAINS:
            img = 1 - img
        ax.axis("off")
        ax.imshow(img)
fig.subplots_adjust(wspace=0, hspace=0)
fig.savefig("G_syn.png")

In [None]:
from IPython.display import FileLink, display

# Download links for the sample grid saved above and for real.png.
# NOTE(review): no cell in this notebook writes real.png; it is presumably
# produced elsewhere. FileLink renders a "path doesn't exist" message if absent.
display(FileLink("G_syn.png"), FileLink("real.png"))

In [None]:
# Class and domain embeddings fed to G for the label tensors `y` / `yd` built
# for the sample grid. Notebook-level equivalent of
# `self.G.shared(gy), self.G.shared_d(gyd)`, whose names (`self`, `gy`, `gyd`)
# do not exist here and made this cell raise NameError.
G.shared(y), G.shared_d(yd)

In [None]:

  # Resume from a checkpoint when requested. utils.load_weights receives G, D,
  # state_dict and, when EMA is enabled, G_ema.
  if config['resume']:
    print('Loading weights...')
    checkpoint_suffix = config['load_weights'] or None
    ema_target = G_ema if config['ema'] else None
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name,
                       checkpoint_suffix, ema_target)

  # Multi-GPU: wrap GD in DataParallel; with cross_replica also set, apply
  # patch_replication_callback (presumably for cross-replica batch norm).
  if config['parallel']:
    GD = nn.DataParallel(GD)
  if config['parallel'] and config['cross_replica']:
    patch_replication_callback(GD)

  # Loggers: test_log gets evaluation metrics (a .jsonl file), train_log gets
  # training metrics. reinitialize is True unless resuming.
  logs_root = config['logs_root']
  fresh_logs = not config['resume']
  test_metrics_fname = '%s/%s_log.jsonl' % (logs_root, experiment_name)
  print(test_metrics_fname)
  train_metrics_fname = '%s/%s' % (logs_root, experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=fresh_logs)
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname, reinitialize=fresh_logs,
                             logstyle=config['logstyle'])
  # Write run metadata (config and state_dict) under the logs root.
  utils.write_metadata(logs_root, experiment_name, config, state_dict)
  # Data. One loader batch supplies a full D iteration (all D steps and
  # accumulations); G draws its own noise and needs no loader.
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])

  transforms_train = transforms.Compose([
      transforms.Resize(config['resolution']),
      transforms.ToTensor(),
      transforms.Normalize([0.5], [0.5]),
  ])

  def _make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by every loader in this cell."""
    return torch.utils.data.DataLoader(dataset, batch_size=D_batch_size, shuffle=shuffle,
                                       num_workers=config['num_workers'],
                                       pin_memory=True,
                                       worker_init_fn=np.random.seed,
                                       drop_last=True)

  print(config['base_root'])
  data_set = source_domain_numpy(root=config['base_root'], root_list=config['source_dataset'],
                                 transform=transforms_train)
  loaders = _make_loader(data_set, shuffle=True)

  # NOTE(review): test_set_s and test_set_t are built from identical arguments
  # (both use config['target_dataset']), so the "s" and "t" evaluations see the
  # same data. Confirm whether the source test set should use other datasets.
  test_set_s = domain_test_numpy(root=config['base_root'],
                                 root_t=config['target_dataset'], transform=transforms_train)
  test_loader_s = _make_loader(test_set_s, shuffle=False)

  test_set_t = domain_test_numpy(root=config['base_root'],
                                 root_t=config['target_dataset'], transform=transforms_train)
  test_loader_t = _make_loader(test_set_t, shuffle=False)

  # Noise/label buffers; G's batch may be larger than the real-data batch.
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  noise_args = (G_batch_size, G.dim_z, config['n_classes'], config['n_domain'])
  noise_kwargs = dict(device=device, fp16=config['G_fp16'])
  z_, y_, yd_ = utils.prepare_z_y(*noise_args, **noise_kwargs)
  # A second set, sampled once here, to follow fixed samples through training.
  fixed_z, fixed_y, fixed_yd = utils.prepare_z_y(*noise_args, **noise_kwargs)
  for fixed_buf in (fixed_z, fixed_y, fixed_yd):
    fixed_buf.sample_()
  # Per-iteration training step. Any value other than 'GAN' is treated as a
  # debugging run and gets the dummy training function. (The inception-metrics
  # sampler built from utils.sample is disabled in this notebook.)
  if config['which_train_fn'] != 'GAN':
    train = train_fns.dummy_training_function()
  else:
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, yd_,
                                            ema, state_dict, config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for the configured number of epochs; progress is also counted in
  # iterations via state_dict['itr'].
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Every 10 epochs, run test_acc with D on both test loaders.
    if epoch % 10 == 0:
      test_acc(D, test_loader_s, epoch, "s")
      test_acc(D, test_loader_t, epoch, "t")
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(loaders, displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders)
    for i, (x_s, y, yd) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Back to training mode; the save/test branches below may call eval().
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()
      # Move the batch to the device. Under D_fp16 only the images are cast to
      # half precision; the class/domain labels stay integer. (Previously this
      # branch referenced an undefined `x_t` and left `yd` on the CPU.)
      if config['D_fp16']:
        x_s, y, yd = x_s.to(device).half(), y.to(device), yd.to(device)
      else:
        x_s, y, yd = x_s.to(device), y.to(device), yd.to(device)
      metrics = train(x_s, y, yd)

      # Every sv_log_interval iterations, log singular values of G and D.
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']),
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
        print(', '.join(['itr: %d' % state_dict['itr']]
                        + ['%s : %+4.3f' % (key, metrics[key])
                           for key in metrics]), end=' ')
        # NOTE(review): `wandb` is not imported in this notebook's visible cells;
        # presumably it comes via `from train import *` — confirm wandb.init ran.
        wandb_metric = {key: metrics[key] if type(metrics[key]) == float else metrics[key].item() for key in metrics.keys()}
        wandb.log(wandb_metric)

      # Save weights and samples every save_every iterations.
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        # Pass G_ema only when EMA is enabled (as the load_weights calls do), so
        # this never touches a possibly-unbound name.
        train_fns.save_and_sample(G, D, G_ema if config['ema'] else None, z_, y_, yd_,
                                  fixed_z, fixed_y, fixed_yd,
                                  state_dict, config, f"{experiment_name}_epoch{epoch}")

      # Every test_every iterations, optionally switch G to eval mode; the
      # metric evaluation (train_fns.test) is disabled in this notebook.
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
