In [1]:
import os
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from itertools import permutations
from tqdm import tqdm
import numpy as np

DATA_PATH = './data'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [2]:
filters_df = pd.read_csv(os.path.join(DATA_PATH, 'algos.csv'), header=None)
f1 = torch.tensor(filters_df.iloc[0].values.reshape(3, 3), dtype=torch.float32, device=device)
f2 = torch.tensor(filters_df.iloc[1].values.reshape(3, 3), dtype=torch.float32, device=device)

In [3]:
transform = transforms.ToTensor()

class ImageWithTxtDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform or transforms.ToTensor()
        self.files = [f for f in os.listdir(data_dir) if f.endswith('.png')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_name = self.files[idx]
        img_path = os.path.join(self.data_dir, img_name)
        image = Image.open(img_path).convert('L')
        image = self.transform(image)  

        txt_name = img_name.replace('.png', '.txt')
        txt_path = os.path.join(self.data_dir, txt_name)
        label = pd.read_csv(txt_path, header=None, sep=' ').values
        label = torch.tensor(label, dtype=torch.float32).unsqueeze(0) 
        return image, label

In [4]:
dataset = ImageWithTxtDataset(DATA_PATH, transform)
loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
X_batch_cpu, Y_batch_cpu = next(iter(loader))
X_batch_cpu *= 256

dtype = torch.float32
X_batch = X_batch_cpu.to(device=device, dtype=dtype)
Y_batch = Y_batch_cpu.to(device=device, dtype=dtype)

In [5]:
f1_w = f1.unsqueeze(0).unsqueeze(0)  
f2_w = f2.unsqueeze(0).unsqueeze(0)

results = []
labels = ['f1', 'f2', 'f_unknown']
perm = ['f2', 'f_unknown', 'f1']

In [6]:
unk = torch.full((1, 1, 3, 3), 0.0625, dtype=torch.float32, device=device, requires_grad=True)
optimizer = torch.optim.Adam([unk], lr=0.1)

for epoch in tqdm(range(10000), desc=f"Perm {perm}"):
    optimizer.zero_grad()
    cur = X_batch.clone()

    for lbl in perm:
        if lbl == 'f1':
            cur = F.conv2d(cur, f1_w, bias=None, padding=1)
        elif lbl == 'f2':
            cur = F.conv2d(cur, f2_w, bias=None, padding=1)
        else:
            cur = F.conv2d(cur, unk, bias=None, padding=1)
    loss = F.mse_loss(cur, Y_batch)

    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, loss={loss.item():.10f}")
    loss.backward()
    optimizer.step()

results.append({'perm': perm, 'loss': loss.item(), 'filter': unk.detach().clone()})
print(f"Permutation {perm}, loss={loss.item():.6f}")

Perm ['f2', 'f_unknown', 'f1']:   1%|          | 79/10000 [00:00<00:27, 354.42it/s]

Epoch 0, loss=4704.5068359375


Perm ['f2', 'f_unknown', 'f1']:  11%|█         | 1083/10000 [00:03<00:25, 350.13it/s]

Epoch 1000, loss=0.0022613765


Perm ['f2', 'f_unknown', 'f1']:  21%|██        | 2083/10000 [00:06<00:22, 345.85it/s]

Epoch 2000, loss=0.0008909916


Perm ['f2', 'f_unknown', 'f1']:  31%|███       | 3087/10000 [00:09<00:20, 344.90it/s]

Epoch 3000, loss=0.0001824660


Perm ['f2', 'f_unknown', 'f1']:  41%|████      | 4085/10000 [00:12<00:16, 354.28it/s]

Epoch 4000, loss=0.0000330281


Perm ['f2', 'f_unknown', 'f1']:  51%|█████     | 5082/10000 [00:15<00:13, 359.50it/s]

Epoch 5000, loss=0.0000059003


Perm ['f2', 'f_unknown', 'f1']:  61%|██████    | 6088/10000 [00:18<00:11, 345.01it/s]

Epoch 6000, loss=0.0000010449


Perm ['f2', 'f_unknown', 'f1']:  71%|███████   | 7085/10000 [00:21<00:08, 349.98it/s]

Epoch 7000, loss=0.0000001900


Perm ['f2', 'f_unknown', 'f1']:  81%|████████  | 8088/10000 [00:24<00:05, 345.57it/s]

Epoch 8000, loss=0.0000173136


Perm ['f2', 'f_unknown', 'f1']:  91%|█████████ | 9086/10000 [00:27<00:02, 351.73it/s]

Epoch 9000, loss=0.0075473450


Perm ['f2', 'f_unknown', 'f1']: 100%|██████████| 10000/10000 [00:30<00:00, 325.50it/s]


Permutation ['f2', 'f_unknown', 'f1'], loss=0.000000


In [7]:
results = sorted(results, key=lambda x: x['loss'])
print(results)

[{'perm': ['f2', 'f_unknown', 'f1'], 'loss': 3.4189044928467638e-09, 'filter': tensor([[[[0.1245, 0.2490, 0.1245],
          [0.2490, 0.4980, 0.2490],
          [0.1245, 0.2490, 0.1245]]]], device='cuda:0')}]


In [10]:
best = results[0]
print(f"Best perm: {best['perm']} with loss {best['loss']:.9f}")
recon_filters = []
for lbl in best['perm']:
    if lbl == 'f1':
        recon_filters.append(f1.cpu())
    elif lbl == 'f2':
        recon_filters.append(f2.cpu())
    else:
        rek = best['filter'].squeeze(0).squeeze(0).cpu()
        recon_filters.append(rek)

# Сохранение reconstructed_algos.csv
recon_array = np.array([f.flatten().numpy() for f in recon_filters])
recon_df = pd.DataFrame(recon_array)
recon_df.to_csv('reconstructed_algos.csv', header=False, index=False)
print('Saved reconstructed_algos.csv')


Best perm: ['f2', 'f_unknown', 'f1'] with loss 0.000000003
Saved reconstructed_algos.csv
