In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from core.config import configs

configs.DATASET.ROOT = r'D:\Code\Project\NeuralTexture_gan\data'
configs.DATASET.MODE = 'ONE_VIEW'

LIGHT_NUM = 384

In [17]:
class DotModel(nn.Module):
    def __init__(self, dim):
        """
        Args:
        - dim: light num
        """
        super(DotModel, self).__init__()
        self.dim = dim
        self.lighting_pattern = nn.Parameter(torch.Tensor(1, dim))     # (1, dim)
        self.reset_parameters()

    def reset_parameters(self):
        init.xavier_uniform_(self.lighting_pattern.data)

    def forward(self, x):
        """
        Args:
        - x: B × dim × H × W  lumitexel
        """
        batch = x.size(0)
        height = x.size(2)
        width = x.size(3)
        
        lp = F.normalize(self.lighting_pattern)                         # (1, dim)

        lp = lp.view(batch, self.dim, 1, 1)                             # (B, dim, 1, 1)

        lp = lp.repeat(1, 1, height, width)                             # (B, dim, H, W)

        res = torch.sum(x * lp, 1)                                      # (B, H, W)

        return res

In [18]:
# from core.datasets.egg import EggDataset

# egg_dataset = EggDataset(root=configs.DATASET.ROOT, is_train=True)

# egg_loader = torch.utils.data.DataLoader(
#         egg_dataset,
#         batch_size=1,
#         shuffle=False,
#         pin_memory=True
#     )

# lumitexel = torch.zeros(1, LIGHT_NUM, 256, 256)
# for i, (_, gt, masks, _, _, _) in enumerate(egg_loader):
#     lumitexel[0, i] = gt[0, 0]

def print_lumi(lumi):
    """
    lumi: (1, 384)
    """
    print ("=> lumitexel:")
    print (lumi)

    # print (torch.argmax(lumitexel[0, 0, 120, :]))

    print ("=> lumitexel normalize:")
    print (F.normalize(lumi))

    print ("=> max dot value:")
    print (torch.sum(lumi * F.normalize(lumi)))

    print ("=> norm:")
    print (torch.norm(lumi))

def compare_lumi(lumi1, lumi2):
    print ("=> lumi1 - lumi2:")
    print (lumi1 - lumi2)

    print ("=> avg |lumi1 - lumi2|: ")
    print (torch.sum(torch.abs(lumi1 - lumi2)) / LIGHT_NUM)

    print ("=> normalize(lumi1) - normalize(lumi2):")
    print (F.normalize(lumi1) - F.normalize(lumi2))

    print ("=> avg |normalize(lumi1) - normalize(lumi2)|: ")
    print (torch.sum(torch.abs(F.normalize(lumi1) - F.normalize(lumi2))) / LIGHT_NUM)

lumitexel = torch.load('data\\lumitexel')

In [19]:
lumi = lumitexel[0:1, :, 120, 81]
print_lumi(lumi)

=> lumitexel:
tensor([[0.8475, 0.6841, 0.5543, 0.3908, 0.9925, 1.2348, 1.1039, 0.8991, 1.0527,
         1.1553, 1.1248, 1.0734, 0.5985, 0.7062, 0.7914, 0.7940, 0.1814, 0.2781,
         0.2678, 0.3060, 0.0894, 0.1155, 0.1263, 0.1114, 0.0503, 0.0626, 0.0691,
         0.0683, 0.0315, 0.0414, 0.0437, 0.0465, 0.0510, 0.0316, 0.0243, 0.0177,
         0.0357, 0.0278, 0.0231, 0.0173, 0.1172, 0.0596, 0.0288, 0.0215, 0.0685,
         0.0420, 0.0275, 0.0208, 0.5238, 0.1125, 0.0297, 0.0253, 0.3742, 0.1201,
         0.0300, 0.0232, 0.2523, 0.0373, 0.0294, 0.0268, 0.5069, 0.0650, 0.0299,
         0.0250, 0.0271, 0.0642, 0.1354, 0.2045, 0.0288, 0.1718, 0.5670, 0.7989,
         0.0324, 0.1481, 0.8127, 2.2488, 0.0282, 0.0863, 0.3127, 0.9179, 0.0246,
         0.0465, 0.1081, 0.1913, 0.0214, 0.0312, 0.0535, 0.0857, 0.0177, 0.0243,
         0.0356, 0.0482, 0.0155, 0.0190, 0.0259, 0.0313, 0.0601, 0.0640, 0.0632,
         0.0553, 0.0382, 0.0412, 0.0382, 0.0347, 0.2757, 0.3427, 0.3624, 0.2815,
         0.105

In [14]:
device = 'cuda'

model = DotModel(LIGHT_NUM).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

x = lumi.view(1, 384, 1, 1).to(device)

for i in range(20000):
    res = model(x)

    sum_loss = -res
    reg_loss = 0

    for name, param in model.named_parameters():
        if 'bn' not in name:
            reg_loss += torch.norm(param)

    loss = sum_loss + reg_loss * 0.1

    loss.backward()

    optimizer.step()
    
    if i % 5000 == 0:
        print (sum_loss.item(), reg_loss.item())

0.16600388288497925 1.4053382873535156
-8.516544342041016 84.35162353515625
-8.355779647827148 47.443092346191406
-8.489959716796875 64.44571685791016


In [15]:
print_lumi(model.lighting_pattern)

=> lumitexel:
Parameter containing:
tensor([[ 8.3523e+00,  6.3180e+00,  4.0549e+00,  3.9473e+00,  8.2981e+00,
          1.1965e+01,  1.0400e+01,  8.2942e+00,  8.6976e+00,  1.0890e+01,
          1.0285e+01,  8.8395e+00,  5.4935e+00,  5.5565e+00,  7.6437e+00,
          6.4284e+00,  2.0905e+00,  3.2694e+00,  1.5141e+00,  3.3032e+00,
         -6.3660e-02,  6.5141e-01,  1.5357e+00,  1.1597e+00,  1.3735e+00,
         -9.9379e-02, -3.0554e-01,  4.3658e-01,  4.3754e-01, -4.5477e-01,
         -4.1613e-01,  1.9795e-01,  2.8561e-01,  1.7736e-01,  6.8010e-01,
         -6.9359e-01,  1.2839e+00,  6.7771e-01,  1.9029e-01,  1.2795e-01,
          1.8062e+00,  3.3503e-01, -2.1545e-01, -3.1608e-01,  1.5746e+00,
          4.6990e-01, -7.0768e-01, -3.6754e-01,  4.9572e+00,  1.4982e+00,
          9.3023e-01, -5.9271e-01,  2.5384e+00,  5.0043e-01, -3.0198e-01,
          5.2815e-01,  2.6456e+00,  5.7192e-01,  1.0093e+00, -2.8403e-01,
          4.9226e+00, -2.8337e-01,  8.7719e-01,  1.2539e-01,  1.6022e-01,
  

In [11]:
compare_lumi(lumi.to(device), model.lighting_pattern)

=> lumi1 - lumi2:
tensor([[-4.0419e-01, -7.2886e-01, -7.6498e-01, -6.1447e-01, -9.9986e-01,
         -1.1813e+00, -1.1862e+00, -5.1457e-01, -1.2317e+00, -6.7640e-01,
         -1.1983e+00, -8.6516e-01, -2.4410e-01, -7.3597e-01, -8.2857e-01,
         -2.5738e-01,  1.3989e-01, -1.1187e-01, -2.5287e-01, -4.8886e-01,
         -8.3984e-02, -4.8046e-01, -2.2612e-01, -3.8661e-01,  1.5783e-03,
         -2.2056e-01,  3.7393e-01,  9.9088e-02, -2.3802e-01, -3.5442e-01,
         -3.0858e-01, -5.9328e-02, -3.1072e-01,  2.2605e-01, -1.5767e-01,
         -1.9574e-01, -1.0968e-01,  2.3646e-01, -1.7410e-01, -2.1113e-01,
         -2.7204e-01, -8.7618e-02,  1.7227e-02,  2.6939e-01, -1.4119e-01,
          2.9770e-01,  2.1867e-01,  1.7764e-01, -2.9051e-01,  8.4656e-02,
         -8.9546e-03,  2.3309e-01, -9.8065e-02, -4.2104e-01,  1.6614e-01,
         -1.0399e-01, -1.6724e-01,  3.5915e-01, -4.6017e-01, -2.5536e-01,
         -3.7265e-01, -4.2823e-01,  1.8991e-01, -4.2397e-01,  1.9825e-01,
          1.2785e-01