In [83]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
%matplotlib inline
import glob
import pickle
from sklearn.preprocessing import MultiLabelBinarizer
from skmultilearn.model_selection.iterative_stratification import iterative_train_test_split
import sys
import random

In [84]:
# Load the two pickled label lookup tables (this is unpickling, not pickling).
# `image_label_pair` is indexed by image id elsewhere in this notebook and yields
# the labels of that image as strings.
# NOTE(review): `label_image_pair` is presumably the inverse mapping
# (label -> images); it is not used in the visible cells -- confirm.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('labels/l_to_i.pyb', "rb") as fp:
    label_image_pair = pickle.load(fp)
with open('labels/i_l.pyb', "rb") as fp:
    image_label_pair = pickle.load(fp)

In [85]:
def gen_img_path(fnames, prefix=None, suffix='.jpg'):
    """Build one image path per name as ``prefix + name + suffix``.

    Args:
        fnames: Iterable of image names/ids. Each element is interpolated
            into an f-string, so ints, NumPy ints and strings all work.
        prefix: Leading part of every path. When ``None`` (the default) the
            notebook-level ``img_path_prefix`` is used, which is what the
            original implementation always did.
        suffix: Trailing part of every path. Defaults to ``'.jpg'``.

    Returns:
        list[str]: The generated paths, in the same order as ``fnames``.

    Raises:
        NameError: If ``prefix`` is ``None`` and ``img_path_prefix`` has not
            been defined yet.
    """
    if prefix is None:
        # Backward-compatible fallback to the notebook-level setting.
        prefix = img_path_prefix
    return [f'{prefix}{name}{suffix}' for name in fnames]
        

In [86]:
def gen_y_vals(img_names, del_axis=None):
    """One-hot encode the labels of the given images.

    Every image is looked up in the notebook-level ``image_label_pair``
    mapping, its labels are converted to ``int`` and the result is binarized
    against the fixed class list 1..13 (``np.arange(1, 14)``).

    Args:
        img_names: Iterable of keys into ``image_label_pair``.
        del_axis: Optional column index (or list of indices) removed from the
            encoded matrix with ``np.delete(..., axis=1)``. Columns are
            0-based and follow the class list, so column ``k`` belongs to
            class ``k + 1``. ``None`` keeps all columns.

    Returns:
        The binarized label matrix with one row per entry of ``img_names``.

    Raises:
        Whatever ``image_label_pair`` raises for an unknown image
        (``KeyError`` for a plain dict).
    """
    labels_per_image = [
        [int(label) for label in image_label_pair[name]]
        for name in img_names
    ]
    binarizer = MultiLabelBinarizer(classes=np.arange(1, 14))
    y_vals = binarizer.fit_transform(labels_per_image)
    if del_axis is not None:
        y_vals = np.delete(y_vals, del_axis, axis=1)
    return y_vals
   
        

In [87]:
def add_label(fn, label):
    """Return ``fn`` with ``_<label>`` inserted before the extension.

    The result always ends in ``.jpg``; the original extension of ``fn`` is
    dropped whatever its length.

    Example:
        ``add_label('/masks/632.jpg', 1)`` -> ``'/masks/632_1.jpg'``

    Args:
        fn: Path of the source file.
        label: Value appended to the file stem (formatted with ``str()``).

    Returns:
        str: Path in the same directory as ``fn`` named
        ``<stem>_<label>.jpg``.
    """
    # os.path.split keeps a leading root ('/x.jpg' -> '/'), which the previous
    # manual '/'.split/join dropped; splitext handles extensions of any length.
    directory, filename = os.path.split(fn)
    stem, _ext = os.path.splitext(filename)
    return os.path.join(directory, f'{stem}_{label}.jpg')

In [88]:
def gen_seg_path(fnames, y_vals, prefix, percent=1):
    """Collect segmentation-mask paths for a stratified subset of images.

    A subset of ``fnames`` is drawn with ``iterative_train_test_split`` using
    ``y_vals`` as the stratification labels. Images inside the subset get the
    list of files matching ``<prefix><name>_*``; all other images get an
    empty list.

    Args:
        fnames: 1-D NumPy array of image names/ids (``reshape`` is called on
            it, so a plain list is not accepted).
        y_vals: Label matrix with one row per entry of ``fnames``.
        prefix: Leading part of the glob pattern used to find mask files.
        percent: Passed as ``test_size`` to ``iterative_train_test_split``;
            the "test" part of that split is the subset that keeps its masks.
            The default ``1`` requests the whole input.

    Returns:
        list[list[str]]: One list per entry of ``fnames`` (same order). Each
        inner list is sorted, because ``glob`` itself returns matches in
        arbitrary order.

    Side effects:
        Prints the number of images in the selected subset.
    """
    # Return order is (X_train, y_train, X_test, y_test); only X_test is used.
    _, _, seg_train, _ = iterative_train_test_split(
        fnames.reshape(-1, 1), y_vals, test_size=percent
    )
    print(len(seg_train))

    # Vectorised membership test instead of an `in` check per image.
    keep_masks = np.isin(fnames, seg_train)

    seg_paths = []
    for name, keep in zip(fnames, keep_masks):
        # Only touch the file system for images that actually keep their masks.
        seg_paths.append(sorted(glob.glob(f'{prefix}{name}_*')) if keep else [])
    return seg_paths

In [89]:
# Hard-coded image ids of the training split (converted to a NumPy array in a later cell).
train_data =  [1, 2, 3, 5, 6, 8, 11, 12, 14, 16, 17, 18, 19, 21, 22, 24, 27, 28, 30, 31, 32, 33, 34, 35, 37, 38, 41, 43, 44, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 59, 60, 62, 63, 64, 65, 66, 67, 70, 72, 73, 75, 76, 78, 79, 81, 82, 83, 85, 86, 88, 89, 91, 92, 94, 95, 96, 98, 99, 101, 102, 104, 105, 107, 108, 111, 112, 113, 114, 115, 117, 120, 121, 123, 124, 126, 127, 128, 129, 130, 131, 133, 134, 136, 137, 139, 140, 142, 143, 144, 145, 146, 147, 149, 152, 153, 155, 156, 158, 159, 160, 161, 162, 163, 165, 166, 168, 171, 172, 174, 175, 176, 177, 178, 181, 182, 184, 185, 187, 188, 190, 191, 192, 193, 195, 197, 198, 200, 201, 203, 204, 206, 207, 208, 210, 211, 213, 214, 216, 217, 219, 220, 222, 223, 224, 225, 226, 227, 229, 230, 232, 233, 235, 236, 238, 240, 241, 242, 243, 245, 246, 248, 251, 252, 254, 255, 256, 257, 258, 259, 261, 262, 264, 265, 268, 270, 271, 272, 273, 274, 275, 277, 280, 281, 283, 284, 286, 287, 288, 289, 291, 293, 294, 296, 299, 300, 302, 303, 304, 305, 306, 307, 310, 312, 313, 315, 316, 318, 319, 321, 322, 323, 325, 326, 328, 329, 331, 332, 334, 335, 336, 337, 339, 341, 342, 344, 345, 347, 348, 350, 351, 352, 353, 354, 355, 357, 358, 360, 361, 363, 366, 367, 368, 369, 370, 371, 373, 374, 376, 377, 380, 382, 383, 384, 385, 386, 387, 389, 390, 392, 393, 396, 398, 399, 400, 401, 402, 403, 405, 406, 408, 409, 411, 412, 414, 416, 417, 418, 419, 421, 422, 424, 425, 427, 428, 430, 431, 432, 433, 434, 435, 437, 438, 440, 441, 443, 444, 446, 447, 448, 449, 450, 451, 453, 454, 456, 457, 459, 460, 462, 463, 464, 465, 466, 467, 469, 470, 472, 473, 475, 476, 478, 479, 480, 481, 482, 483, 485, 486, 488, 489, 491, 492, 494, 495, 496, 497, 498, 499, 501, 502, 504, 505, 507, 508, 510, 512, 513, 514, 515, 517, 518, 520, 521, 523, 524, 526, 527, 528, 529, 530, 531, 533, 534, 536, 537, 539, 540, 542, 543, 544, 545, 546, 547, 549, 550, 552, 553, 555, 558, 559, 561, 562, 563, 565, 566, 568, 569, 571, 572, 574, 575, 576, 577, 578, 579, 581, 582, 584, 585, 587, 588, 
590, 591, 592, 593, 594, 595, 597, 598, 600, 601, 603, 604, 606, 607, 608, 609, 610, 611, 613, 616, 617, 619, 620, 622, 623, 624, 625, 626, 627, 629, 630, 632, 633, 635, 636, 638, 640, 641, 642, 643, 645, 646, 648, 649, 651, 652, 654, 655, 656, 657, 658, 659, 661, 662, 664, 665, 667, 668, 671, 672, 673, 674, 675, 678, 680, 681, 683, 684, 686, 687, 688, 689, 690, 691, 693, 694, 696, 697, 699, 702, 703, 704, 705, 706, 707, 709, 710, 712, 713, 715, 716, 718, 719, 720, 721, 722, 723, 725, 726, 728, 729, 731, 734, 735, 737, 738, 739, 741, 742, 744, 745, 747, 748, 750, 751, 752, 753, 754, 755, 757, 760, 761, 763, 764, 766, 767, 768, 769, 770, 773, 774, 777, 779, 780, 782, 783, 784, 785, 786, 787, 789, 790, 792, 793, 796, 798, 800, 801, 802, 805, 806, 808, 809, 811, 812, 814, 815, 816, 818, 819, 821, 822, 824, 825, 827, 828, 830, 831, 832, 833, 834, 837, 838, 840, 841, 843, 844, 846, 847, 848, 849, 850, 851, 853, 854, 855, 857, 859, 860, 862, 863, 864, 865, 866, 867, 870, 871, 872, 875, 876, 878, 879, 880, 881, 882, 883, 886, 888, 889, 891, 892, 894, 895, 896, 897, 898, 899, 901, 902, 904, 907, 908, 910, 911, 912, 913, 914, 915, 917, 918, 920, 921, 923, 924, 926, 927, 928, 929, 930, 931, 933, 934, 936, 937, 939, 940, 942, 944, 945, 946, 947, 949, 950, 952, 953, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 968, 969, 971, 972, 974, 975, 976, 977, 978, 979, 981, 982, 984, 987, 988, 990, 991, 992, 995, 997, 998, 1000]

In [90]:
# Hard-coded image ids of the validation split (converted to a NumPy array in a later cell).
val_data = [9, 13, 25, 26, 40, 42, 58, 69, 74, 80, 90, 97, 106, 110, 119, 135, 150, 151, 167, 169, 179, 183, 194, 199, 209, 215, 231, 239, 247, 249, 263, 267, 278, 279, 290, 295, 308, 309, 320, 324, 338, 340, 356, 364, 372, 381, 395, 410, 415, 426, 442, 458, 474, 490, 506, 511, 522, 538, 554, 560, 612, 614, 637, 639, 669, 670, 679, 700, 730, 736, 756, 758, 775, 776, 794, 799, 810, 817, 835, 836, 852, 856, 869, 885, 887, 893, 903, 905, 935, 948, 967, 973, 983, 985, 993, 994]

In [91]:
# Scratch cell: a hard-coded id list whose length is displayed as the cell output.
# NOTE(review): this looks like the tail of train_data from id 600 onwards and is
# not referenced by any other visible cell -- confirm before relying on it.
temp = [600, 601, 603, 604, 606, 607, 608, 609, 610, 611, 613, 616, 617, 619, 620, 622, 623, 624, 625, 626, 627, 629, 630, 632, 633, 635, 636, 638, 640, 641, 642, 643, 645, 646, 648, 649, 651, 652, 654, 655, 656, 657, 658, 659, 661, 662, 664, 665, 667, 668, 671, 672, 673, 674, 675, 678, 680, 681, 683, 684, 686, 687, 688, 689, 690, 691, 693, 694, 696, 697, 699, 702, 703, 704, 705, 706, 707, 709, 710, 712, 713, 715, 716, 718, 719, 720, 721, 722, 723, 725, 726, 728, 729, 731, 734, 735, 737, 738, 739, 741, 742, 744, 745, 747, 748, 750, 751, 752, 753, 754, 755, 757, 760, 761, 763, 764, 766, 767, 768, 769, 770, 773, 774, 777, 779, 780, 782, 783, 784, 785, 786, 787, 789, 790, 792, 793, 796, 798, 800, 801, 802, 805, 806, 808, 809, 811, 812, 814, 815, 816, 818, 819, 821, 822, 824, 825, 827, 828, 830, 831, 832, 833, 834, 837, 838, 840, 841, 843, 844, 846, 847, 848, 849, 850, 851, 853, 854, 855, 857, 859, 860, 862, 863, 864, 865, 866, 867, 870, 871, 872, 875, 876, 878, 879, 880, 881, 882, 883, 886, 888, 889, 891, 892, 894, 895, 896, 897, 898, 899, 901, 902, 904, 907, 908, 910, 911, 912, 913, 914, 915, 917, 918, 920, 921, 923, 924, 926, 927, 928, 929, 930, 931, 933, 934, 936, 937, 939, 940, 942, 944, 945, 946, 947, 949, 950, 952, 953, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 968, 969, 971, 972, 974, 975, 976, 977, 978, 979, 981, 982, 984, 987, 988, 990, 991, 992, 995, 997, 998, 1000]
len(temp)

281

In [92]:
# Hard-coded image ids of the test split (converted to a NumPy array in a later cell).
test_data = [4, 7, 10, 15, 20, 23, 29, 36, 39, 45, 52, 55, 61, 68, 71, 77, 84, 87, 93, 100, 103, 109, 116, 118, 122, 125, 132, 138, 141, 148, 154, 157, 164, 170, 173, 180, 186, 189, 196, 202, 205, 212, 218, 221, 228, 234, 237, 244, 250, 253, 260, 266, 269, 276, 282, 285, 292, 297, 298, 301, 311, 314, 317, 327, 330, 333, 343, 346, 349, 359, 362, 365, 375, 378, 379, 388, 391, 394, 397, 404, 407, 413, 420, 423, 429, 436, 439, 445, 452, 455, 461, 468, 471, 477, 484, 487, 493, 500, 503, 509, 516, 519, 525, 532, 535, 541, 548, 551, 556, 557, 564, 567, 570, 573, 580, 583, 586, 589, 596, 599, 602, 605, 615, 618, 621, 628, 631, 634, 644, 647, 650, 653, 660, 663, 666, 676, 677, 682, 685, 692, 695, 698, 701, 708, 711, 714, 717, 724, 727, 732, 733, 740, 743, 746, 749, 759, 762, 765, 771, 772, 778, 781, 788, 791, 795, 797, 803, 804, 807, 813, 820, 823, 826, 829, 839, 842, 845, 858, 861, 868, 873, 874, 877, 884, 890, 900, 906, 909, 916, 919, 922, 925, 932, 938, 941, 943, 951, 954, 970, 980, 986, 989, 996, 999]

In [93]:
# Turn the three id lists into NumPy arrays so later cells can filter them
# with boolean masks (e.g. ``train_data[train_data > 600]``).
train_data, test_data, val_data = (
    np.array(ids) for ids in (train_data, test_data, val_data)
)

In [94]:
# Fraction handed to gen_seg_path (as its `percent` argument) for the training
# split; the validation and test splits use that function's default instead.
percent = 0.25 * 0.25

In [95]:
#segmentation dataset 1
# Seed NumPy's global RNG before any split is drawn.
# NOTE(review): presumably this is what makes the stratified sub-sampling inside
# gen_seg_path repeatable -- confirm that skmultilearn draws from np.random.
np.random.seed(1)

img_path_prefix = "/home/hpc/iwfa/iwfa024h/dataset/Images_3000x1500/"
seg_path_prefix = "/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/"

# Ids above this value feed the *_pos_path / *_y_vals / *_seg_path entries,
# ids at or below it feed the *_neg_path entries.
POS_ID_THRESHOLD = 600
# 0-based column(s) removed from the one-hot label matrix by gen_y_vals.
DROPPED_LABEL_COLUMNS = [2]


def _build_split(ids, seg_percent=None):
    """Derive the four per-split artefacts from one array of image ids.

    Args:
        ids: NumPy array of image ids for one split.
        seg_percent: Forwarded to ``gen_seg_path`` as ``percent``. ``None``
            leaves that function's own default in effect.

    Returns:
        tuple: ``(pos_path, y_vals, seg_path, neg_path)`` where the first
        three are computed from ``ids > POS_ID_THRESHOLD`` and the last from
        the remaining ids.
    """
    pos_ids = ids[ids > POS_ID_THRESHOLD]
    neg_ids = ids[ids <= POS_ID_THRESHOLD]

    pos_path = gen_img_path(pos_ids)
    y_vals = gen_y_vals(pos_ids, del_axis=DROPPED_LABEL_COLUMNS)
    if seg_percent is None:
        seg_path = gen_seg_path(pos_ids, y_vals, seg_path_prefix)
    else:
        seg_path = gen_seg_path(pos_ids, y_vals, seg_path_prefix, seg_percent)
    neg_path = gen_img_path(neg_ids)
    return pos_path, y_vals, seg_path, neg_path


# Keep the train -> val -> test order of the original cell so that any random
# numbers consumed inside gen_seg_path are drawn in the same sequence.
train_pos_path, train_y_vals, train_seg_path, train_neg_path = _build_split(train_data, percent)
val_pos_path, val_y_vals, val_seg_path, val_neg_path = _build_split(val_data)
test_pos_path, test_y_vals, test_seg_path, test_neg_path = _build_split(test_data)

# Positional bundle that gets pickled; consumers index into it, so the order
# below must not change.
val = [train_pos_path,   # 0
       test_pos_path,    # 1
       train_seg_path,   # 2
       test_seg_path,    # 3
       train_neg_path,   # 4
       test_neg_path,    # 5
       train_y_vals,     # 6
       test_y_vals,      # 7
       val_pos_path,     # 8
       val_seg_path,     # 9
       val_neg_path,     # 10
       val_y_vals]       # 11



24
36
84


In [80]:
# Persist the assembled split list `val` to disk.
split_file = "splits/PA_M/split_65_D2.pyb"
with open(split_file, "wb") as out_fh:
    pickle.dump(val, out_fh)

In [81]:
# Read the split list back from disk, replacing the in-memory `val`.
split_file = "splits/PA_M/split_65_D2.pyb"
with open(split_file, "rb") as in_fh:
    val = pickle.load(in_fh)

In [82]:
# Entry 2 of the split list is train_seg_path (see the order in which `val` is
# assembled): one list of mask paths per training image with id > 600, empty
# when the image was not selected or no mask file matched.
val[2]

[[],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/632_1.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/635_1.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/638_1.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/641_1.jpg'],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/661_2.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/664_2.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/667_2.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/671_2.jpg'],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/691_3.jpg'],
 [],
 ['/home/hpc/iwfa/iwfa024h/dataset/dataset2/mask_512x256/694_3.jpg'],
 [],
 ['/home/h

In [29]:
# Load a previously saved split list from its absolute location.
# The recorded run of this cell ended in "EOFError: Ran out of input", i.e.
# pickle reached end-of-file before it could read an object.
split_file = "/home/hpc/iwfa/iwfa024h/Multi_Label_optical_Segmentation/splits/PA_M/split_0_D1.pyb"
with open(split_file, "rb") as in_fh:
    val = pickle.load(in_fh)

EOFError: Ran out of input

In [31]:
# Display entry 2 (the train_seg_path slot) of the current `val`.
# NOTE(review): the load recorded just above raised EOFError, and the execution
# counts are out of order, so it is unclear which split file this output
# belongs to -- TODO confirm.
val[2]

[[],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],


In [17]:
# Tally how often each label occurs among the training images.
imgs = train_data

count_dict = {}   # label (as stored in image_label_pair) -> number of occurrences
false_ls = []     # every label occurrence converted to int, in encounter order
for img_id in imgs:
    # Images without an entry in the lookup table are ignored.
    if img_id in image_label_pair:
        for label in image_label_pair[img_id]:
            false_ls.append(int(label))
            count_dict[label] = count_dict.get(label, 0) + 1

print('count classification: ', count_dict)


count classification:  {'3': 595, '1': 35, '2': 35, '4': 35, '5': 35, '6': 34, '7': 35, '8': 36, '9': 35, '10': 35, '11': 35, '12': 35, '13': 35}
