# Through the void

In [None]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from IPython.display import display, Math, Latex

manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results

random.seed(manualSeed)
torch.manual_seed(manualSeed)

<torch._C.Generator at 0x7f0a6521f9f0>

*загружаем дата-сет*

In [None]:
dataroot = '/content/img/'
workers = 2
batch_size = 128
image_size = 64
nc = 3
ngf = 64
ndf = 64
num_epochs = 5
lr = 0.0002
beta1 = 0.5
ngpu = 1

In [None]:
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
train_data = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
val_data = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

None

In [None]:
class AE(nn.Module):
    
        
    def __init__(self, encoder, decoder, ngpu):
        super(AE, self).__init__()
        self.ngpu = ngpu
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
    
        x = self.encoder(x)
        x = self.decoder(x)
        return x




class encoder(torch.nn.Module):
    def __init__(self):
        super(encoder, self).__init__()
        
        self.enc = nn.Sequential(
          # input is Z, going into a convolution
            nn.Conv2d( 3, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.Conv2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.Conv2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )
    def forward(self, x):
        return self.enc(x)
    
class decoder(torch.nn.Module):
    def __init__(self):
        super(decoder, self).__init__()
        
        self.drop_layer = nn.Dropout(p=0.001)
        self.dec =nn.Sequential(
        nn.ConvTranspose2d(ngf * 2, ngf * 4,4, 2, 1, bias=False),
        nn.ReLU(),
        nn.ConvTranspose2d(ngf * 4, ngf * 8,4, 2, 1, bias=False),
        nn.ReLU(),
        nn.ConvTranspose2d(ngf * 8, 3,4, 2, 1, bias=False),
        nn.Tanh()
                        )  
        
    def forward(self, x):
        return self.drop_layer(self.dec(x))


In [None]:
manualSeed = random.randint(1, 10000)
net = AE(encoder(), decoder(),ngpu).to(device)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    net = nn.DataParallel(net, list(range(ngpu)))
  
# optimizer.param_groups[0]['params'][0][0]

In [None]:
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0002)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=(20, 60), gamma=1/2)
img_list=[]

l_1 = list(net.parameters()).copy()



In [None]:
def train(epochs, net, criterion, optimizer, train_loader, val_loader,scheduler=None, verbose=True, save_dir=None):
    
    freq = max(epochs//20,1)
    net.to(device)
    
    iters=0
    for epoch in range(1, epochs+1):
        net.train()

        losses_train = []
        for X, _ in train_loader:
            # Performing one step of minibatch stochastic gradient descent
            
            X = X.to(device)
            out = net(X) 
            loss = criterion(X,out)       

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses_train.append(loss.item())

            if (iters % 500 == 0) or ((epoch == num_epochs-1)):
            
              img_list.append((vutils.make_grid(out, padding=2, normalize=True)).cpu())

        iters += 1

        
        net.eval()
        for X, _ in val_loader:
            losses_val = []
            
            X = X.to(device)
            out = net(X)
            loss_val = criterion(X, out)
            losses_val.append(loss_val.item())

        if scheduler is not None:
            scheduler.step()
        
        if verbose and epoch%freq==0:
            mean_val = sum(losses_val)/len(losses_val)
            mean_train = sum(losses_train)/len(losses_train)

            print('Epoch {}/{} || Loss:  Train {:.4f} | Validation {:.4f}'\
                  .format(epoch, epochs, mean_train, mean_val))

In [None]:
train(5, net, criterion, optimizer, train_data, val_data, scheduler)

Epoch 1/5 || Loss:  Train 0.4537 | Validation 0.3896
Epoch 2/5 || Loss:  Train 0.3176 | Validation 0.2935
Epoch 3/5 || Loss:  Train 0.2479 | Validation 0.2970
Epoch 4/5 || Loss:  Train 0.2099 | Validation 0.1977
Epoch 5/5 || Loss:  Train 0.1846 | Validation 0.1825


In [None]:
l_2 = list(net.parameters()).copy()
l_1[0][1] == l_2[0][1]

tensor([[[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]], device='cuda:0')

In [None]:
param_1 = optimizer.param_groups[0]['params'][0] # это веса == list(net.parameters())

manualSeed = random.randint(1, 10000)
net = AE(encoder(), decoder(),ngpu).to(device)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
img_list=[]
criterion = nn.MSELoss()

train(10, net, criterion, optimizer, train_data, val_data, scheduler)

Epoch 1/10 || Loss:  Train 0.3570 | Validation 0.3503
Epoch 2/10 || Loss:  Train 0.3600 | Validation 0.3503
Epoch 3/10 || Loss:  Train 0.3557 | Validation 0.3504
Epoch 4/10 || Loss:  Train 0.3546 | Validation 0.3504
Epoch 5/10 || Loss:  Train 0.3550 | Validation 0.3505
Epoch 6/10 || Loss:  Train 0.3560 | Validation 0.3506
Epoch 7/10 || Loss:  Train 0.3522 | Validation 0.3508
Epoch 8/10 || Loss:  Train 0.3543 | Validation 0.3510
Epoch 9/10 || Loss:  Train 0.3559 | Validation 0.3513
Epoch 10/10 || Loss:  Train 0.3567 | Validation 0.3515


# Laplacian Pyramid loss

In [None]:
def gauss_kernel(size=5, device=torch.device('cpu'), channels=3):
    kernel = torch.tensor([[1., 4., 6., 4., 1],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    kernel /= 256.
    kernel = kernel.repeat(channels, 1, 1, 1)
    kernel = kernel.to(device)
    return kernel

def downsample(x):
    return x[:, :, ::2, ::2]

def upsample(x):
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
    cc = cc.permute(0,1,3,2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
    x_up = cc.permute(0,1,3,2)
    return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1], device=x.device))

def conv_gauss(img, kernel):
    img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
    out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
    return out

def laplacian_pyramid(img, kernel, max_levels=3):
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        down = downsample(filtered)
        up = upsample(down)
        diff = current-up
        pyr.append(diff)
        current = down
    return pyr

class LapLoss(torch.nn.Module):
    def __init__(self, max_levels=3, channels=3, device=torch.device('cuda')):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.gauss_kernel = gauss_kernel(channels=channels, device=device)
        
    def forward(self, input, target):
        pyr_input  = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))

In [None]:
manualSeed = random.randint(1, 10000)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

net = AE(encoder(), decoder(),ngpu).to(device)

criterion = LapLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0002) 
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=(20, 60), gamma=0.5)
img_list=[]
train(10, net, criterion, optimizer, train_data, val_data, scheduler)

Epoch 1/10 || Loss:  Train 0.3070 | Validation 0.2929
Epoch 2/10 || Loss:  Train 0.2907 | Validation 0.2902
Epoch 3/10 || Loss:  Train 0.2765 | Validation 0.2715
Epoch 4/10 || Loss:  Train 0.2322 | Validation 0.2148
Epoch 5/10 || Loss:  Train 0.2004 | Validation 0.1915
Epoch 6/10 || Loss:  Train 0.1820 | Validation 0.1848
Epoch 7/10 || Loss:  Train 0.1688 | Validation 0.1707
Epoch 8/10 || Loss:  Train 0.1570 | Validation 0.1609
Epoch 9/10 || Loss:  Train 0.1490 | Validation 0.1488
Epoch 10/10 || Loss:  Train 0.1402 | Validation 0.1375


## Finding  $\; \omega_{i1}, \; \omega_{i2} \;$ from different random starts and save
итог лежит в param_list_end: 

param_list_end[0] = $\omega_{i1}$

param_list_end[1] = $\omega_{i2}$

In [None]:
param_list_begin = [] # веса до обучения, это для проверки
param_list_end = [] # веса после обучения, это для проверки

for _ in range(2):
  manualSeed = random.randint(1, 10000)
  random.seed(manualSeed)
  torch.manual_seed(manualSeed)

  net = AE(encoder(), decoder(),ngpu).to(device)
  criterion = nn.L1Loss()
  optimizer = torch.optim.Adam(net.parameters(), lr=0.0002) 
  scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=(20, 60), gamma=0.5)
  l_1 = list(net.parameters()).copy()
  param_list_begin.append(l_1)
  img_list=[]

  train(10, net, criterion, optimizer, train_data, val_data, scheduler)

  l_2 = list(net.parameters()).copy()
  param_list_end.append(l_2)



Epoch 1/10 || Loss:  Train 0.4610 | Validation 0.3914
Epoch 2/10 || Loss:  Train 0.3281 | Validation 0.3092
Epoch 3/10 || Loss:  Train 0.2588 | Validation 0.3121
Epoch 4/10 || Loss:  Train 0.2213 | Validation 0.2120
Epoch 5/10 || Loss:  Train 0.1940 | Validation 0.1880
Epoch 6/10 || Loss:  Train 0.1805 | Validation 0.1787
Epoch 7/10 || Loss:  Train 0.1761 | Validation 0.1701
Epoch 8/10 || Loss:  Train 0.1577 | Validation 0.1539
Epoch 9/10 || Loss:  Train 0.1536 | Validation 0.1558
Epoch 10/10 || Loss:  Train 0.1435 | Validation 0.1421
True
Epoch 1/10 || Loss:  Train 0.4726 | Validation 0.4245
Epoch 2/10 || Loss:  Train 0.3036 | Validation 0.2787
Epoch 3/10 || Loss:  Train 0.2440 | Validation 0.2435
Epoch 4/10 || Loss:  Train 0.2080 | Validation 0.2033
Epoch 5/10 || Loss:  Train 0.1860 | Validation 0.2050
Epoch 6/10 || Loss:  Train 0.1745 | Validation 0.1730
Epoch 7/10 || Loss:  Train 0.1610 | Validation 0.1593
Epoch 8/10 || Loss:  Train 0.1507 | Validation 0.1592
Epoch 9/10 || Loss:  T

In [None]:
param_list_begin[0][0][0] == param_list_begin[1][0][0]

tensor([[[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]]], device='cuda:0')

# Curves

In [None]:
import numpy as np
import math
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torch.nn.modules.utils import _pair
from scipy.special import binom


class Bezier(Module):
    def __init__(self, num_bends):
        super(Bezier, self).__init__()
        self.register_buffer(
            'binom',
            torch.Tensor(binom(num_bends - 1, np.arange(num_bends), dtype=np.float32))
        )
        self.register_buffer('range', torch.arange(0, float(num_bends)))
        self.register_buffer('rev_range', torch.arange(float(num_bends - 1), -1, -1))

    def forward(self, t):
        return self.binom * \
               torch.pow(t, self.range) * \
               torch.pow((1.0 - t), self.rev_range)


class PolyChain(Module):
    def __init__(self, num_bends):
        super(PolyChain, self).__init__()
        self.num_bends = num_bends
        self.register_buffer('range', torch.arange(0, float(num_bends)))

    def forward(self, t):
        t_n = t * (self.num_bends - 1)
        return torch.max(self.range.new([0.0]), 1.0 - torch.abs(t_n - self.range))


class CurveModule(Module):

    def __init__(self, fix_points, parameter_names=()):
        super(CurveModule, self).__init__()
        self.fix_points = fix_points
        self.num_bends = len(self.fix_points)
        self.parameter_names = parameter_names
        self.l2 = 0.0

    def compute_weights_t(self, coeffs_t):
        w_t = [None] * len(self.parameter_names)
        self.l2 = 0.0
        for i, parameter_name in enumerate(self.parameter_names):
            for j, coeff in enumerate(coeffs_t):
                parameter = getattr(self, '%s_%d' % (parameter_name, j))
                if parameter is not None:
                    if w_t[i] is None:
                        w_t[i] = parameter * coeff
                    else:
                        w_t[i] += parameter * coeff
            if w_t[i] is not None:
                self.l2 += torch.sum(w_t[i] ** 2)
        return w_t


class Linear(CurveModule):

    def __init__(self, in_features, out_features, fix_points, bias=True):
        super(Linear, self).__init__(fix_points, ('weight', 'bias'))
        self.in_features = in_features
        self.out_features = out_features

        self.l2 = 0.0
        for i, fixed in enumerate(self.fix_points):
            self.register_parameter(
                'weight_%d' % i,
                Parameter(torch.Tensor(out_features, in_features), requires_grad=not fixed)
            )
        for i, fixed in enumerate(self.fix_points):
            if bias:
                self.register_parameter(
                    'bias_%d' % i,
                    Parameter(torch.Tensor(out_features), requires_grad=not fixed)
                )
            else:
                self.register_parameter('bias_%d' % i, None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.in_features)
        for i in range(self.num_bends):
            getattr(self, 'weight_%d' % i).data.uniform_(-stdv, stdv)
            bias = getattr(self, 'bias_%d' % i)
            if bias is not None:
                bias.data.uniform_(-stdv, stdv)

    def forward(self, input, coeffs_t):
        weight_t, bias_t = self.compute_weights_t(coeffs_t)
        return F.linear(input, weight_t, bias_t)


class Conv2d(CurveModule):

    def __init__(self, in_channels, out_channels, kernel_size, fix_points, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(fix_points, ('weight', 'bias'))
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        for i, fixed in enumerate(self.fix_points):
            self.register_parameter(
                'weight_%d' % i,
                Parameter(
                    torch.Tensor(out_channels, in_channels // groups, *kernel_size),
                    requires_grad=not fixed
                )
            )
        for i, fixed in enumerate(self.fix_points):
            if bias:
                self.register_parameter(
                    'bias_%d' % i,
                    Parameter(torch.Tensor(out_channels), requires_grad=not fixed)
                )
            else:
                self.register_parameter('bias_%d' % i, None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        for i in range(self.num_bends):
            getattr(self, 'weight_%d' % i).data.uniform_(-stdv, stdv)
            bias = getattr(self, 'bias_%d' % i)
            if bias is not None:
                bias.data.uniform_(-stdv, stdv)

    def forward(self, input, coeffs_t):
        weight_t, bias_t = self.compute_weights_t(coeffs_t)
        return F.conv2d(input, weight_t, bias_t, self.stride,
                        self.padding, self.dilation, self.groups)


class _BatchNorm(CurveModule):
    _version = 2

    def __init__(self, num_features, fix_points, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(fix_points, ('weight', 'bias'))
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        self.l2 = 0.0
        for i, fixed in enumerate(self.fix_points):
            if self.affine:
                self.register_parameter(
                    'weight_%d' % i,
                    Parameter(torch.Tensor(num_features), requires_grad=not fixed)
                )
            else:
                self.register_parameter('weight_%d' % i, None)
        for i, fixed in enumerate(self.fix_points):
            if self.affine:
                self.register_parameter(
                    'bias_%d' % i,
                    Parameter(torch.Tensor(num_features), requires_grad=not fixed)
                )
            else:
                self.register_parameter('bias_%d' % i, None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            for i in range(self.num_bends):
                getattr(self, 'weight_%d' % i).data.uniform_()
                getattr(self, 'bias_%d' % i).data.zero_()

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, input, coeffs_t):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None: 
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else: 
                exponential_average_factor = self.momentum
        weight_t, bias_t = self.compute_weights_t(coeffs_t)
        return F.batch_norm(
            input, self.running_mean, self.running_var, weight_t, bias_t,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_BatchNorm, self)._load_from_state_dict(
            state_dict, prefix, metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


class BatchNorm2d(_BatchNorm):

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class CurveNet(Module):
    def __init__(self, num_classes, curve, architecture, num_bends, fix_start=True, fix_end=True,
                 architecture_kwargs={}):
        super(CurveNet, self).__init__()
        self.num_classes = num_classes
        self.num_bends = num_bends
        self.fix_points = [fix_start] + [False] * (self.num_bends - 2) + [fix_end]
        
        self.curve = curve
        self.architecture = architecture

        self.l2 = 0.0
        self.coeff_layer = self.curve(self.num_bends)
        self.net = self.architecture(num_classes, fix_points=self.fix_points, **architecture_kwargs)
        self.curve_modules = []
        for module in self.net.modules():
            if issubclass(module.__class__, CurveModule):
                self.curve_modules.append(module)

    def import_base_parameters(self, base_model, index):
        parameters = list(self.net.parameters())[index::self.num_bends]
        base_parameters = base_model.parameters()
        for parameter, base_parameter in zip(parameters, base_parameters):
            parameter.data.copy_(base_parameter.data)

    def import_base_buffers(self, base_model):
        for buffer, base_buffer in zip(self.net._all_buffers(), base_model._all_buffers()):
            buffer.data.copy_(base_buffer.data)

    def export_base_parameters(self, base_model, index):
        parameters = list(self.net.parameters())[index::self.num_bends]
        base_parameters = base_model.parameters()
        for parameter, base_parameter in zip(parameters, base_parameters):
            base_parameter.data.copy_(parameter.data)

    def init_linear(self):
        parameters = list(self.net.parameters())
        for i in range(0, len(parameters), self.num_bends):
            weights = parameters[i:i+self.num_bends]
            for j in range(1, self.num_bends - 1):
                alpha = j * 1.0 / (self.num_bends - 1)
                weights[j].data.copy_(alpha * weights[-1].data + (1.0 - alpha) * weights[0].data)

    def weights(self, t):
        coeffs_t = self.coeff_layer(t)
        weights = []
        for module in self.curve_modules:
            weights.extend([w for w in module.compute_weights_t(coeffs_t) if w is not None])
        return np.concatenate([w.detach().cpu().numpy().ravel() for w in weights])

    def _compute_l2(self):
        self.l2 = sum(module.l2 for module in self.curve_modules)

    def forward(self, input, t=None):
        if t is None:
            t = input.data.new(1).uniform_()
        coeffs_t = self.coeff_layer(t)
        output = self.net(input, coeffs_t)
        self._compute_l2()
        return output


def l2_regularizer(weight_decay):
    return lambda model: 0.5 * weight_decay * model.l2