# Introduction

The purpose of this project is to train a convolutional neural network (CNN) to predict the material of a piece of garbage in order to improve waste sorting. We used a dataset from Kaggle made up of 2,527 images across 6 labels. We trained CNN models on 80% of the data, validated them on 10%, and tested them on the remaining 10%. Initially, the model's accuracy was extremely low, very close to 0%. However, after changing the learning rate and increasing the number of epochs, we managed to reach up to 62% accuracy. We will now process the data in the following steps: converting the images into tensors, saving the tensors to disk, splitting the dataset into training, validation, and test sets, and creating the final dataset. Once that's done, we will be able to start building the model.

## Setting up the Data

In [1]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import torch
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
import datasets

In [2]:
# Helper that converts one image file into a tensor file on disk
def convert_img_path_into_tensor(img_path, tensor_path):
    """Read the image at ``img_path`` and save it as a tensor at ``tensor_path``.

    Args:
        img_path: Path of the image file to open with PIL.
        tensor_path: Destination path handed to ``torch.save``.

    Returns:
        None. The only effect is the file written at ``tensor_path``.
    """
    # Open through a context manager so the underlying file handle is closed
    # as soon as the pixels are converted; otherwise one handle per image
    # stays open until garbage collection.
    with Image.open(img_path) as image:
        # Convert inside the ``with`` block: PIL loads lazily, so the pixel
        # data must be read while the file is still open.
        img_tensor = transforms.PILToTensor()(image)
    torch.save(img_tensor, tensor_path)  # persist the tensor to disk

In [3]:
# Map each garbage class name to its integer label (labels follow list order)
classes_dict = {
    name: label
    for label, name in enumerate(
        ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
    )
}

In [4]:
# Root folder of the raw images, relative to the working directory
dataset_path = (
    Path("garbage_dataset") / "garbage_classification" / "garbage_classification"
)

In [5]:
# Convert every image of every class folder into a saved tensor file
for k in classes_dict:
    path = dataset_path / k

    # Tensors of class k are written to tensors/<k>.
    # exist_ok=True lets this cell be re-run without a FileExistsError.
    path_to_save = Path('tensors') / k
    path_to_save.mkdir(parents=True, exist_ok=True)

    # Listing the folder up front (instead of passing the bare iterdir()
    # generator) gives tqdm a total, so it can draw a real progress bar,
    # and processes the files in a stable order.
    for p in tqdm(sorted(path.iterdir())):
        # with_suffix swaps only the final extension, whatever its case,
        # so every output file is named <image stem>.pt
        convert_img_path_into_tensor(p, path_to_save / p.with_suffix('.pt').name)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [6]:
# Two separate accumulators: image tensors and their integer class labels
list_of_tensors, list_of_labels = [], []

In [7]:
# Load every saved tensor together with the integer label of its class folder.
# iterdir() yields entries in arbitrary order, and the split further down
# selects rows by position from a hard-coded index list, so both levels are
# sorted to make the row order identical on every run and machine.
for path in sorted(Path('tensors').iterdir()):
    k = path.name  # the folder name is the class name
    for p in sorted(path.iterdir()):
        tensor = torch.load(p)
        list_of_tensors.append(tensor)
        list_of_labels.append(classes_dict[k])  # KeyError if the folder is not a known class

In [8]:
# Build one dataset with an "images" column and a matching "labels" column
dataset = datasets.Dataset.from_dict(
    dict(images=list_of_tensors, labels=list_of_labels)
)

In [9]:
# Persist the full (not yet split) dataset to the local
# "processed_garbage_dataset" directory
dataset.save_to_disk("processed_garbage_dataset")

Saving the dataset (0/4 shards):   0%|          | 0/2527 [00:00<?, ? examples/s]

In [13]:
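# Hard-coded, already-shuffled list of dataset row indices. No shuffling happens in this cell; the order is fixed so the split below is the same on every run.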
idxs = [1960, 249, 617, 1534, 694, 1813, 1914, 933, 78, 1290, 544, 2184, 972, 559, 2433, 99, 994, 1041, 684, 2152, 1521, 2517, 176, 1668, 1532, 532, 2446, 1027, 203, 1460, 1728, 522, 1392, 1987, 1434, 2052, 659, 2342, 36, 1663, 332, 470, 306, 1118, 2458, 304, 1432, 719, 830, 297, 530, 150, 2338, 2220, 1840, 1544, 1832, 1968, 1586, 1397, 142, 1437, 2323, 1705, 453, 1774, 1426, 1067, 1889, 188, 1773, 93, 1619, 1517, 1546, 1206, 1738, 683, 1144, 1312, 2468, 1457, 11, 799, 2154, 580, 363, 2425, 2018, 2391, 1819, 1114, 276, 1922, 2456, 1675, 665, 2416, 794, 1991, 977, 2296, 699, 687, 1722, 1450, 696, 415, 1569, 735, 356, 2294, 1972, 2258, 926, 2437, 1243, 1215, 889, 1789, 27, 2402, 1537, 1745, 2213, 1277, 1229, 2026, 40, 656, 1575, 822, 569, 833, 702, 195, 372, 919, 30, 556, 1234, 120, 1910, 216, 2411, 1887, 1653, 1068, 2518, 104, 79, 309, 731, 358, 2212, 260, 2360, 1860, 1162, 257, 979, 318, 180, 1407, 1332, 550, 1853, 2046, 2107, 2241, 354, 1909, 1513, 1514, 1689, 1648, 262, 408, 660, 2001, 894, 1032, 2051, 1953, 187, 107, 495, 2239, 2083, 2097, 242, 989, 205, 2171, 2482, 2047, 1525, 1782, 871, 986, 252, 2135, 2165, 811, 1978, 1476, 1160, 554, 1999, 551, 851, 73, 1153, 172, 829, 1724, 1316, 1106, 1282, 942, 698, 2340, 1703, 1626, 804, 508, 2393, 364, 1400, 1404, 2100, 1462, 251, 1484, 652, 2434, 355, 1858, 2022, 2234, 957, 5, 1602, 133, 1078, 1422, 2384, 421, 2454, 1962, 797, 47, 62, 98, 753, 1822, 2501, 1419, 2495, 952, 164, 50, 1298, 92, 1235, 1190, 1186, 2078, 1563, 3, 1448, 467, 782, 487, 1871, 653, 1314, 506, 101, 2140, 2111, 1599, 610, 1395, 2273, 431, 1249, 1057, 14, 307, 1540, 595, 520, 1801, 134, 1254, 915, 1707, 1685, 2169, 1143, 849, 590, 1895, 964, 1287, 1973, 2196, 261, 344, 1127, 1122, 2064, 1927, 578, 392, 449, 1494, 283, 767, 1620, 1829, 1445, 2406, 1719, 1777, 900, 2210, 2098, 2185, 1029, 1289, 2075, 2370, 1963, 1511, 1863, 113, 657, 2347, 622, 1794, 562, 555, 2142, 1501, 855, 1367, 949, 2061, 2408, 2459, 1364, 2263, 1385, 1355, 2348, 1948, 1342, 
2422, 2168, 2208, 35, 536, 762, 1137, 2290, 2058, 2311, 1374, 1493, 1639, 175, 2326, 455, 9, 61, 585, 1409, 1075, 1136, 1353, 2386, 2204, 1250, 279, 1937, 1417, 1940, 715, 1553, 2484, 1752, 552, 1730, 1346, 1753, 896, 1691, 194, 1986, 613, 1851, 1754, 732, 2016, 2067, 817, 1347, 22, 1884, 1069, 2322, 597, 679, 2076, 531, 348, 54, 267, 565, 2478, 275, 1545, 232, 1302, 1443, 2417, 711, 2092, 330, 1100, 1147, 1944, 2128, 1461, 1974, 394, 1388, 1873, 529, 589, 2136, 2321, 2205, 465, 705, 1908, 1452, 611, 1741, 488, 1247, 2480, 958, 1431, 2362, 1845, 424, 932, 1427, 405, 2314, 1132, 2153, 2504, 1199, 264, 1786, 2441, 1101, 640, 1542, 1398, 2419, 862, 1187, 241, 567, 2088, 234, 2216, 227, 680, 2134, 2073, 1582, 2512, 1576, 925, 2157, 1905, 890, 754, 1879, 818, 325, 1294, 1640, 1415, 427, 74, 480, 1997, 391, 95, 2316, 880, 2438, 821, 1, 1678, 1278, 858, 339, 417, 1018, 1798, 1429, 1947, 1444, 2106, 1128, 1672, 250, 1610, 2187, 2498, 1402, 368, 193, 2373, 1131, 72, 897, 2261, 482, 210, 2031, 87, 1223, 768, 1980, 1811, 1733, 1533, 2410, 218, 974, 1788, 848, 730, 369, 1601, 2011, 1926, 2297, 1917, 204, 505, 839, 154, 2120, 125, 772, 2104, 1898, 946, 789, 1306, 1571, 669, 497, 1530, 1340, 2014, 1103, 947, 496, 1473, 2350, 2415, 266, 2409, 475, 2277, 91, 2062, 557, 12, 1893, 2262, 1939, 639, 1327, 2139, 616, 2209, 1924, 2462, 429, 1369, 997, 734, 1195, 1225, 827, 2335, 511, 1261, 876, 2276, 2144, 477, 2337, 798, 1133, 636, 1985, 1589, 1343, 1505, 360, 857, 130, 760, 1061, 1562, 1384, 186, 2440, 1012, 1650, 1735, 452, 1359, 109, 2344, 1700, 788, 1766, 1167, 2396, 2500, 1025, 2049, 954, 888, 843, 510, 362, 1732, 2397, 1471, 2334, 944, 759, 867, 418, 346, 1762, 44, 162, 1060, 1105, 112, 1748, 956, 1299, 1867, 147, 837, 2113, 953, 1783, 1570, 2451, 122, 1850, 1351, 903, 1197, 326, 965, 2006, 2101, 198, 412, 1747, 690, 1259, 1086, 1938, 2063, 212, 1995, 384, 2508, 2143, 1715, 1770, 2211, 1217, 689, 1758, 1791, 1410, 2429, 845, 1961, 1531, 1380, 725, 206, 1011, 560, 2027, 971, 2477, 
1023, 1669, 973, 2503, 1408, 1336, 1056, 835, 936, 1042, 603, 1904, 2312, 464, 1220, 1983, 387, 2435, 2307, 574, 2469, 1551, 1681, 1964, 1510, 1043, 1201, 1713, 1110, 1358, 2059, 633, 1512, 494, 2490, 2514, 2195, 244, 2070, 2367, 1246, 2333, 106, 507, 1945, 1883, 2355, 132, 1059, 825, 2293, 1418, 2309, 2035, 1825, 1102, 674, 1993, 1421, 2299, 127, 1135, 1607, 403, 484, 1580, 628, 1286, 2226, 1799, 898, 1275, 178, 1076, 1198, 651, 1956, 207, 1666, 886, 43, 425, 1959, 1564, 370, 1613, 1420, 379, 1492, 1694, 2444, 1224, 1982, 1024, 208, 2033, 2224, 300, 2090, 1479, 728, 548, 1328, 324, 852, 375, 1381, 961, 941, 1483, 1365, 64, 2222, 583, 967, 100, 1529, 46, 812, 86, 1793, 1731, 357, 1231, 901, 1305, 1729, 254, 1568, 1488, 1874, 1077, 152, 138, 481, 2034, 1742, 173, 238, 2115, 308, 2050, 31, 1810, 703, 708, 1386, 65, 1097, 883, 103, 2161, 2256, 1002, 143, 1827, 831, 272, 1695, 662, 1930, 2233, 1300, 225, 2305, 533, 775, 2392, 1622, 1558, 71, 278, 2191, 621, 779, 226, 2523, 713, 2308, 1536, 1901, 1950, 282, 10, 2074, 1523, 1708, 1209, 48, 2028, 398, 1918, 1876, 2199, 2065, 577, 1387, 1440, 2491, 1921, 1481, 800, 1455, 1780, 1248, 627, 446, 237, 331, 860, 1528, 726, 1126, 2439, 1965, 1098, 1430, 1795, 2376, 1711, 571, 624, 1192, 1159, 1013, 401, 2328, 320, 801, 2194, 2525, 682, 483, 1189, 1363, 606, 1082, 460, 1232, 1555, 1806, 42, 1045, 2285, 2264, 1265, 2496, 2145, 2251, 1055, 1977, 2354, 1352, 2403, 290, 274, 938, 1800, 1865, 1897, 1319, 1603, 1907, 1242, 105, 1664, 1739, 2221, 591, 474, 1591, 1092, 2246, 2237, 2252, 1303, 16, 998, 1490, 712, 287, 1856, 1522, 1820, 1674, 1270, 602, 1857, 1984, 2003, 1584, 165, 1913, 1751, 1524, 454, 410, 2, 1969, 235, 727, 28, 749, 1036, 1257, 2395, 980, 1285, 485, 1988, 349, 447, 1578, 430, 750, 1081, 2201, 486, 1676, 1717, 1712, 2232, 1684, 458, 224, 737, 2452, 183, 791, 612, 748, 1526, 909, 329, 2319, 771, 1317, 402, 160, 2486, 523, 2359, 641, 1771, 184, 2124, 1411, 2275, 291, 26, 1019, 869, 1016, 1439, 80, 1366, 2214, 1554, 1744, 
1714, 1150, 509, 1572, 1451, 940, 2114, 1474, 1172, 2371, 707, 1478, 2137, 1301, 2492, 2024, 692, 1438, 1047, 764, 1051, 686, 1208, 912, 598, 1919, 2301, 2461, 322, 63, 2056, 1826, 323, 928, 770, 377, 2380, 2021, 115, 2138, 371, 939, 644, 97, 808, 1778, 1049, 1469, 1585, 1818, 1258, 406, 1718, 378, 1477, 1792, 409, 333, 1130, 1617, 2150, 600, 2352, 2008, 572, 648, 1096, 619, 1734, 763, 1394, 1652, 1344, 2190, 336, 219, 943, 191, 1268, 976, 90, 2174, 756, 540, 681, 2280, 1872, 1472, 2412, 587, 1764, 2215, 2160, 2332, 2163, 1882, 1124, 647, 1981, 561, 1157, 1637, 2053, 1658, 917, 795, 2494, 1763, 2178, 790, 744, 1737, 2193, 519, 1322, 2066, 457, 1608, 1099, 573, 1228, 2002, 411, 582, 1706, 314, 288, 1185, 2102, 435, 1573, 2206, 1740, 1979, 81, 1875, 528, 1338, 200, 637, 1070, 1541, 675, 214, 2288, 1008, 2304, 2471, 302, 59, 2284, 978, 2358, 546, 1245, 739, 124, 2180, 2382, 1323, 1177, 1227, 987, 381, 2378, 2170, 169, 1165, 2464, 29, 1425, 969, 1458, 981, 781, 1828, 670, 672, 2442, 2038, 321, 1214, 295, 626, 2040, 1736, 2254, 1350, 1281, 2383, 168, 82, 676, 2121, 2388, 1835, 1688, 436, 813, 2488, 121, 382, 189, 1928, 2231, 1489, 538, 803, 2404, 1015, 2147, 2037, 1906, 1210, 607, 153, 802, 2313, 902, 351, 1629, 1482, 766, 317, 52, 761, 563, 2122, 472, 2244, 1307, 196, 2418, 2126, 836, 1955, 2377, 541, 463, 814, 950, 2278, 1428, 1989, 525, 1952, 2266, 1814, 1065, 1414, 995, 2166, 123, 438, 1808, 2176, 1701, 1181, 1038, 1230, 863, 2228, 2361, 2487, 2019, 1824, 1117, 1868, 328, 1583, 850, 2387, 1140, 1665, 466, 1309, 1345, 49, 491, 1203, 2428, 299, 1966, 841, 2302, 1623, 1171, 751, 1403, 1017, 213, 1567, 1592, 1761, 1491, 2282, 2423, 128, 2449, 2227, 1916, 1720, 2202, 1390, 1506, 1765, 1237, 1176, 1709, 2032, 1377, 1375, 968, 844, 1166, 2436, 2081, 2247, 1841, 1527, 584, 2336, 1925, 534, 2133, 1262, 468, 1391, 1900, 41, 2421, 667, 629, 608, 2109, 1934, 197, 593, 1447, 661, 170, 1329, 265, 1161, 2029, 2398, 444, 1326, 185, 873, 84, 343, 469, 2470, 921, 832, 1682, 1007, 
2353, 910, 116, 2453, 2346, 2069, 1063, 327, 228, 1649, 2240, 2030, 451, 960, 1341, 668, 501, 407, 688, 1957, 1769, 2082, 161, 2172, 301, 2223, 1692, 774, 434, 1655, 1164, 383, 2516, 765, 292, 1325, 549, 2369, 233, 2497, 2186, 1175, 145, 1267, 2167, 396, 558, 1318, 783, 1378, 1843, 1470, 605, 780, 2493, 1499, 1836, 2085, 2414, 1815, 1080, 816, 1005, 211, 442, 2048, 717, 1085, 1354, 316, 1216, 1755, 1566, 787, 1200, 2236, 77, 1048, 1256, 2463, 443, 2175, 2238, 2182, 1645, 655, 1971, 842, 521, 1293, 879, 437, 718, 1146, 2447, 1679, 677, 2473, 285, 2385, 350, 1112, 1839, 913, 1079, 2430, 2300, 846, 2125, 366, 1886, 2158, 1809, 695, 1393, 2141, 2476, 17, 1292, 2366, 1812, 1174, 2005, 601, 489, 386, 2087, 2330, 828, 1509, 1538, 853, 1633, 2116, 1044, 2217, 1503, 23, 1661, 2036, 1842, 1179, 1581, 2431, 268, 2255, 1213, 1628, 420, 714, 7, 678, 1113, 53, 1155, 535, 2489, 1284, 1767, 1139, 1496, 1946, 1787, 1244, 1507, 1274, 376, 1315, 502, 1446, 2389, 1817, 1321, 1311, 284, 1115, 2460, 1207, 1951, 1485, 840, 1441, 1361, 2315, 66, 642, 1837, 1896, 2306, 650, 2200, 966, 1062, 158, 359, 248, 599, 20, 588, 586, 2230, 2472, 1497, 492, 1449, 1579, 1687, 2164, 1943, 390, 111, 1830, 1892, 499, 1111, 1616, 209, 2427, 96, 500, 527, 1196, 1726, 1495, 1357, 1263, 2479, 709, 1330, 875, 1680, 2400, 769, 490, 1807, 2329, 785, 395, 1412, 236, 1859, 1031, 916, 1656, 517, 1885, 167, 2271, 2023, 570, 2235, 1502, 215, 664, 784, 1998, 1781, 671, 243, 1145, 24, 1205, 1372, 1463, 1296, 918, 239, 2045, 2099, 461, 1260, 32, 2374, 1424, 462, 1240, 1662, 2450, 1834, 1933, 2349, 1849, 2375, 1862, 2110, 1690, 2207, 2283, 1785, 895, 1399, 2119, 1621, 110, 393, 2351, 1226, 337, 1609, 1697, 1890, 777, 1500, 643, 1677, 2339, 1033, 1519, 400, 1006, 2004, 1550, 163, 2015, 1253, 870, 1288, 347, 2390, 389, 1686, 823, 2357, 1188, 514, 543, 623, 1183, 1990, 1615, 1349, 1949, 1821, 171, 990, 240, 982, 1516, 2020, 746, 658, 1823, 2343, 881, 2467, 281, 1881, 2117, 89, 740, 56, 666, 1560, 512, 945, 654, 1028, 2474, 
1749, 2068, 1149, 1743, 2295, 2105, 2146, 724, 159, 2010, 2481, 2401, 1072, 1912, 1021, 1310, 2279, 2424, 8, 2524, 988, 177, 1480, 1556, 618, 1486, 1093, 51, 1184, 1071, 2242, 293, 1903, 1660, 280, 1182, 478, 2103, 1401, 2413, 217, 221, 2475, 1594, 38, 2179, 2265, 547, 752, 721, 397, 882, 2522, 2327, 2513, 149, 930, 166, 1255, 834, 1657, 866, 1169, 615, 820, 1084, 594, 247, 1635, 319, 1273, 223, 824, 1611, 67, 1590, 414, 568, 2189, 1659, 738, 905, 673, 1151, 2274, 199, 1855, 202, 1074, 256, 1156, 991, 868, 1339, 2043, 1915, 887, 1716, 2272, 2129, 270, 2188, 141, 1618, 962, 1702, 19, 294, 69, 604, 1211, 144, 229, 1373, 2509, 39, 1848, 441, 340, 385, 1604, 2426, 545, 1565, 1954, 819, 504, 1383, 2407, 1727, 757, 2192, 539, 60, 1587, 984, 45, 878, 747, 456, 691, 1064, 2483, 2249, 259, 1976, 2093, 1923, 1279, 1598, 2268, 1368, 1436, 2198, 806, 1148, 1854, 1994, 922, 459, 1337, 1593, 479, 872, 2506, 1356, 1805, 137, 2368, 2159, 1750, 1561, 286, 2108, 1632, 1710, 1141, 2181, 1911, 513, 1091, 1389, 963, 1334, 1154, 1000, 526, 2457, 2012, 1759, 503, 2123, 1370, 2465, 181, 2042, 1467, 1034, 1667, 190, 1416, 904, 1574, 1050, 2044, 258, 338, 1847, 743, 1107, 1252, 117, 2013, 2248, 1456, 2218, 1920, 1442, 148, 1612, 353, 1866, 1634, 1040, 433, 245, 70, 2432, 1266, 693, 1003, 2325, 399, 2079, 1308, 1001, 1577, 231, 156, 131, 179, 305, 1673, 2317, 34, 1125, 1779, 865, 793, 2155, 1803, 1625, 2055, 2320, 432, 135, 854, 931, 1802, 1173, 1382, 76, 1852, 2000, 1520, 2345, 1600, 342, 733, 1654, 2289, 861, 1178, 416, 1104, 722, 2219, 1543, 246, 2364, 2007, 476, 75, 1180, 927, 1498, 352, 592, 1004, 423, 1348, 1052, 923, 15, 1877, 758, 1790, 741, 810, 126, 1935, 1087, 1768, 1405, 85, 1202, 1651, 1693, 884, 2229, 0, 1557, 2017, 729, 57, 21, 2057, 439, 1119, 2086, 230, 1297, 2291, 1535, 1283, 2243, 1776, 1547, 1698, 2112, 1291, 2520, 999, 1596, 2118, 638, 2054, 311, 815, 55, 2149, 645, 856, 1109, 1846, 1942, 1816, 1026, 136, 2443, 1671, 1756, 289, 273, 1129, 18, 118, 1784, 1838, 566, 1138, 
2096, 1088, 1958, 1518, 428, 310, 970, 1454, 129, 2445, 2156, 1142, 2365, 755, 1331, 1120, 1549, 1967, 1595, 2260, 2298, 992, 807, 1376, 2259, 1894, 576, 1272, 2071, 649, 632, 1870, 1552, 2245, 2162, 146, 2039, 874, 1433, 1642, 1030, 2356, 1683, 1636, 361, 2095, 2485, 776, 723, 1614, 1239, 2286, 983, 1475, 1453, 1095, 220, 253, 1194, 646, 1760, 1170, 720, 1515, 2281, 1295, 1335, 1996, 1134, 2270, 303, 1775, 498, 2225, 380, 119, 33, 892, 2318, 345, 1158, 1094, 2466, 2502, 182, 701, 1010, 315, 805, 2526, 277, 847, 706, 911, 419, 1219, 929, 1929, 2127, 1508, 1746, 2510, 1465, 891, 575, 2331, 1466, 959, 773, 2094, 1627, 742, 404, 388, 1559, 1597, 1548, 2499, 58, 2009, 1696, 614, 1362, 1869, 1725, 1222, 778, 1932, 537, 635, 1757, 335, 1191, 1168, 1721, 1163, 1878, 2257, 1406, 553, 1899, 1058, 1039, 1014, 2287, 625, 471, 1090, 1833, 1371, 2151, 2292, 2132, 1423, 1641, 1970, 1212, 1796, 114, 2303, 2084, 1643, 155, 1054, 1624, 1699, 2310, 581, 663, 996, 1379, 1264, 1902, 2131, 1324, 341, 1435, 596, 2379, 448, 908, 1108, 899, 1035, 1504, 1313, 1396, 2519, 859, 2089, 2203, 2077, 1704, 885, 924, 745, 1269, 450, 2372, 296, 139, 1588, 2130, 1116, 2177, 2515, 948, 1864, 1464, 524, 1083, 1638, 893, 334, 826, 1121, 1941, 620, 4, 1606, 2041, 37, 1605, 907, 1280, 426, 2269, 493, 440, 1861, 1152, 13, 1936, 906, 796, 374, 579, 445, 313, 83, 710, 2399, 1251, 157, 914, 1880, 1804, 809, 2183, 1931, 564, 1193, 1046, 1053, 413, 2455, 920, 263, 1631, 630, 1022, 2148, 1241, 985, 515, 1333, 609, 2173, 1236, 2521, 1468, 298, 1487, 1320, 2448, 975, 2080, 1304, 1020, 192, 2381, 365, 1271, 1276, 1797, 2025, 88, 1831, 1221, 1123, 174, 2420, 1772, 1037, 934, 685, 1238, 2060, 1218, 631, 222, 373, 1066, 140, 1992, 1630, 473, 2394, 68, 1723, 2507, 1413, 2405, 1891, 1888, 1360, 367, 955, 542, 271, 102, 1459, 786, 516, 2324, 1647, 700, 108, 151, 2341, 1670, 877, 1073, 2091, 993, 935, 792, 1539, 864, 518, 2250, 2505, 1975, 736, 704, 634, 2511, 1089, 2197, 951, 1009, 312, 838, 2072, 2253, 6, 422, 2363, 
269, 255, 1204, 1844, 937, 716, 25, 201, 94, 697, 1644, 2267, 1646, 1233]

In [14]:
# Cut the shuffled index list into consecutive 80% / 10% / 10% slices
# for the train, validation, and test sets respectively
train_idxs, valid_idxs, test_idxs = (
    idxs[int(start * len(idxs)):int(stop * len(idxs))]
    for start, stop in ((0.0, 0.8), (0.8, 0.9), (0.9, 1.0))
)

In [15]:
# Materialise each split by selecting its rows from the full dataset
train_dataset, valid_dataset, test_dataset = (
    dataset.select(split_idxs)
    for split_idxs in (train_idxs, valid_idxs, test_idxs)
)

In [16]:
# Bundle the three splits under their names and write them to "split_dataset"
datasets.DatasetDict(
    dict(train=train_dataset, valid=valid_dataset, test=test_dataset)
).save_to_disk("split_dataset")

Saving the dataset (0/3 shards):   0%|          | 0/2021 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/253 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/253 [00:00<?, ? examples/s]